root / elixir / trunk / elixir / entity.py @ 327

Revision 327, 42.4 kB (checked in by ged, 5 years ago)

- Fixed inheritance with autoloaded entities: when using autoload, we

shouldn't try to add columns to the table (closes tickets #41 and #43).

- Fixed ColumnProperty to work with latest version of SQLAlchemy (O.4.5 and

later)

- Added AUTHORS list. If you are missing from this list, don't hesitate to

contact me.

- Merged test_autoload_nopk into test_autoload

Line 
1'''
2This module provides the ``Entity`` base class, as well as its metaclass
3``EntityMeta``.
4'''
5
6from py23compat import set, rsplit, sorted
7
8import sys
9import warnings
10
11import sqlalchemy
12from sqlalchemy                    import Table, Column, Integer, \
13                                          desc, ForeignKey, and_, \
14                                          ForeignKeyConstraint
15from sqlalchemy.orm                import Query, MapperExtension, \
16                                          mapper, object_session, EXT_PASS, \
17                                          polymorphic_union
18from sqlalchemy.ext.sessioncontext import SessionContext
19
20import elixir
21from elixir.statements import process_mutators
22from elixir import options
23from elixir.properties import Property
24
25
26__doc_all__ = ['Entity', 'EntityMeta']
27
28
29try: 
30    from sqlalchemy.orm import ScopedSession
31except ImportError: 
32    # Not on sqlalchemy version 0.4
33    ScopedSession = type(None)
34
35
36def _do_mapping(session, cls, *args, **kwargs):
37    if session is None:
38        return mapper(cls, *args, **kwargs)
39    elif isinstance(session, ScopedSession):
40        return session.mapper(cls, *args, **kwargs)
41    elif isinstance(session, SessionContext):
42        extension = kwargs.pop('extension', None)
43        if extension is not None:
44            if not isinstance(extension, list):
45                extension = [extension]
46            extension.append(session.mapper_extension)
47        else:
48            extension = session.mapper_extension
49
50        class query(object):
51            def __getattr__(s, key):
52                return getattr(session.registry().query(cls), key)
53
54            def __call__(s):
55                return session.registry().query(cls)
56
57        if not 'query' in cls.__dict__: 
58            cls.query = query()
59
60        return mapper(cls, extension=extension, *args, **kwargs)
61    else:
62        raise Exception("Failed to map entity '%s' with its table or "
63                        "selectable" % cls.__name__)
64
65
66class EntityDescriptor(object):
67    '''
68    EntityDescriptor describes fields and options needed for table creation.
69    '''
70   
71    def __init__(self, entity):
72        entity.table = None
73        entity.mapper = None
74
75        self.entity = entity
76        self.module = sys.modules[entity.__module__]
77
78        self.has_pk = False
79        self._pk_col_done = False
80
81        self.builders = []
82
83        self.is_base = is_base(entity)
84        self.parent = None
85        self.children = []
86
87        for base in entity.__bases__:
88            if isinstance(base, EntityMeta) and not is_base(base):
89                if self.parent:
90                    raise Exception('%s entity inherits from several entities,'
91                                    ' and this is not supported.' 
92                                    % self.entity.__name__)
93                else:
94                    self.parent = base
95                    self.parent._descriptor.children.append(entity)
96
97        # columns and constraints waiting for a table to exist
98        self._columns = list()
99        self.constraints = list()
100        # properties waiting for a mapper to exist
101        self.properties = dict()
102
103        #
104        self.relationships = list()
105
106        # set default value for options
107        self.order_by = None
108        self.table_args = list()
109
110        # set default value for options with an optional module-level default
111        self.metadata = getattr(self.module, '__metadata__', elixir.metadata)
112        self.session = getattr(self.module, '__session__', elixir.session)
113        self.objectstore = None
114        self.collection = getattr(self.module, '__entity_collection__',
115                                  elixir.entities)
116
117        for option in ('autosetup', 'inheritance', 'polymorphic', 'identity',
118                       'autoload', 'tablename', 'shortnames', 
119                       'auto_primarykey', 'version_id_col', 
120                       'allowcoloverride'):
121            setattr(self, option, options.options_defaults[option])
122
123        for option_dict in ('mapper_options', 'table_options'):
124            setattr(self, option_dict, 
125                    options.options_defaults[option_dict].copy())
126
127    def setup_options(self):
128        '''
129        Setup any values that might depend on using_options. For example, the
130        tablename or the metadata.
131        '''
132        elixir.metadatas.add(self.metadata)
133        if self.collection is not None:
134            self.collection.append(self.entity)
135
136        objectstore = None
137        session = self.session
138        if session is None or isinstance(session, ScopedSession):
139            # no stinking objectstore
140            pass
141        elif isinstance(session, SessionContext):
142            objectstore = elixir.Objectstore(session)
143        elif not hasattr(session, 'registry'):
144            # Both SessionContext and ScopedSession have a registry attribute,
145            # but objectstores (whether Elixir's or Activemapper's) don't, so
146            # if we are here, it means an Objectstore is used for the session.
147#XXX: still true for activemapper post 0.4?           
148            objectstore = session
149            session = objectstore.context
150
151        self.session = session
152        self.objectstore = objectstore
153
154        entity = self.entity
155        if self.parent:
156            if self.inheritance == 'single':
157                self.tablename = self.parent._descriptor.tablename
158
159        if not self.tablename:
160            if self.shortnames:
161                self.tablename = entity.__name__.lower()
162            else:
163                modulename = entity.__module__.replace('.', '_')
164                tablename = "%s_%s" % (modulename, entity.__name__)
165                self.tablename = tablename.lower()
166        elif callable(self.tablename):
167            self.tablename = self.tablename(entity)
168
169        if not self.identity:
170            if 'polymorphic_identity' in self.mapper_options:
171                self.identity = self.mapper_options['polymorphic_identity']
172            else:
173                #TODO: include module name
174                self.identity = entity.__name__.lower()
175        elif 'polymorphic_identity' in kwargs:
176            raise Exception('You cannot use the "identity" option and the '
177                            'polymorphic_identity mapper option at the same '
178                            'time.')
179        elif callable(self.identity):
180            self.identity = self.identity(entity)
181
182        if self.polymorphic:
183            if not isinstance(self.polymorphic, basestring):
184                self.polymorphic = options.DEFAULT_POLYMORPHIC_COL_NAME
185
186    #---------------------
187    # setup phase methods
188
189    def setup_autoload_table(self):
190        self.setup_table(True)
191
192    def create_pk_cols(self):
193        """
194        Create primary_key columns. That is, call the 'create_pk_cols'
195        builders then add a primary key to the table if it hasn't already got
196        one and needs one.
197       
198        This method is "semi-recursive" in some cases: it calls the
199        create_keys method on ManyToOne relationships and those in turn call
200        create_pk_cols on their target. It shouldn't be possible to have an
201        infinite loop since a loop of primary_keys is not a valid situation.
202        """
203        if self._pk_col_done:
204            return
205
206        self.call_builders('create_pk_cols')
207
208        if not self.autoload:
209            if self.parent:
210                if self.inheritance == 'multi':
211                    # Add columns with foreign keys to the parent's primary
212                    # key columns
213                    parent_desc = self.parent._descriptor
214                    schema = parent_desc.table_options.get('schema', None)
215                    tablename = parent_desc.tablename 
216                    if schema is not None:
217                        tablename = "%s.%s" % (schema, tablename)
218                    for pk_col in parent_desc.primary_keys:
219                        colname = options.MULTIINHERITANCECOL_NAMEFORMAT % \
220                                  {'entity': self.parent.__name__.lower(),
221                                   'key': pk_col.key}
222
223                        # It seems like SA ForeignKey is not happy being given
224                        # a real column object when said column is not yet
225                        # attached to a table
226                        pk_col_name = "%s.%s" % (tablename, pk_col.key)
227                        fk = ForeignKey(pk_col_name, ondelete='cascade')
228                        col = Column(colname, pk_col.type, fk,
229                                     primary_key=True)
230                        self.add_column(col)
231                elif self.inheritance == 'concrete':
232                    # Copy primary key columns from the parent.
233                    for col in self.parent._descriptor.columns:
234                        if col.primary_key:
235                            self.add_column(col.copy())
236            elif not self.has_pk and self.auto_primarykey:
237                if isinstance(self.auto_primarykey, basestring):
238                    colname = self.auto_primarykey
239                else:
240                    colname = options.DEFAULT_AUTO_PRIMARYKEY_NAME
241
242                self.add_column(
243                    Column(colname, options.DEFAULT_AUTO_PRIMARYKEY_TYPE, 
244                           primary_key=True))
245        self._pk_col_done = True
246
247    def setup_relkeys(self):
248        self.call_builders('create_non_pk_cols')
249
250    def before_table(self):
251        self.call_builders('before_table')
252       
253    def setup_table(self, only_autoloaded=False):
254        '''
255        Create a SQLAlchemy table-object with all columns that have been
256        defined up to this point.
257        '''
258        if self.entity.table:
259            return
260
261        if self.autoload != only_autoloaded:
262            return
263
264        kwargs = self.table_options
265        if self.autoload:
266            args = self.table_args
267            kwargs['autoload'] = True
268        else:
269            if self.parent:
270                if self.inheritance == 'single':
271                    # we know the parent is setup before the child
272                    self.entity.table = self.parent.table 
273
274                    # re-add the entity columns to the parent entity so that they
275                    # are added to the parent's table (whether the parent's table
276                    # is already setup or not).
277                    for col in self.columns:
278                        self.parent._descriptor.add_column(col)
279                    for constraint in self.constraints:
280                        self.parent._descriptor.add_constraint(constraint)
281                    return
282                elif self.inheritance == 'concrete': 
283                    #TODO: we should also copy columns from the parent table
284                    # if the parent is a base (abstract?) entity (whatever the
285                    # inheritance type -> elif will need to be changed)
286
287                    # Copy all non-primary key columns from parent table
288                    # (primary key columns have already been copied earlier).
289                    for col in self.parent._descriptor.columns:
290                        if not col.primary_key:
291                            self.add_column(col.copy())
292
293                    #FIXME: use the public equivalent of _get_colspec when
294                    #available
295                    for con in self.parent._descriptor.constraints:
296                        self.add_constraint(
297                            ForeignKeyConstraint(
298                                [c.key for c in con.columns],
299                                [e._get_colspec() for e in con.elements],
300                                name=con.name, #TODO: modify it
301                                onupdate=con.onupdate, ondelete=con.ondelete,
302                                use_alter=con.use_alter))
303
304            if self.polymorphic and \
305               self.inheritance in ('single', 'multi') and \
306               self.children and not self.parent:
307                self.add_column(Column(self.polymorphic, 
308                                       options.POLYMORPHIC_COL_TYPE))
309
310            if self.version_id_col:
311                if not isinstance(self.version_id_col, basestring):
312                    self.version_id_col = options.DEFAULT_VERSION_ID_COL_NAME
313                self.add_column(Column(self.version_id_col, Integer))
314
315            args = self.columns + self.constraints + self.table_args
316       
317        self.entity.table = Table(self.tablename, self.metadata, 
318                                  *args, **kwargs)
319
320    def setup_reltables(self):
321        self.call_builders('create_tables')
322
323    def after_table(self):
324        self.call_builders('after_table')
325
326    def setup_events(self):
327        def make_proxy_method(methods):
328            def proxy_method(self, mapper, connection, instance):
329                for func in methods:
330                    ret = func(instance)
331                    # I couldn't commit myself to force people to
332                    # systematicaly return EXT_PASS in all their event methods.
333                    # But not doing that diverge to how SQLAlchemy works.
334                    # I should try to convince Mike to do EXT_PASS by default,
335                    # and stop processing as the special case.
336#                    if ret != EXT_PASS:
337                    if ret is not None and ret != EXT_PASS:
338                        return ret
339                return EXT_PASS
340            return proxy_method
341
342        # create a list of callbacks for each event
343        methods = {}
344        for name, method in self.entity.__dict__.items():
345            if hasattr(method, '_elixir_events'):
346                for event in method._elixir_events:
347                    event_methods = methods.setdefault(event, [])
348                    event_methods.append(method)
349        if not methods:
350            return
351       
352        # transform that list into methods themselves
353        for event in methods:
354            methods[event] = make_proxy_method(methods[event])
355       
356        # create a custom mapper extension class, tailored to our entity
357        ext = type('EventMapperExtension', (MapperExtension,), methods)()
358       
359        # then, make sure that the entity's mapper has our mapper extension
360        self.add_mapper_extension(ext)
361
362    def before_mapper(self):
363        self.call_builders('before_mapper')
364
365    def _get_children(self):
366        children = self.children[:]
367        for child in self.children:
368            children.extend(child._descriptor._get_children())
369        return children
370
371    def translate_order_by(self, order_by):
372        if isinstance(order_by, basestring):
373            order_by = [order_by]
374       
375        order = list()
376        for colname in order_by:
377            col = self.get_column(colname.strip('-'))
378            if colname.startswith('-'):
379                col = desc(col)
380            order.append(col)
381        return order
382
383    def setup_mapper(self):
384        '''
385        Initializes and assign an (empty!) mapper to the entity.
386        '''
387        if self.entity.mapper:
388            return
389
390        # for now we don't support the "abstract" parent class in a concrete
391        # inheritance scenario as demonstrated in
392        # sqlalchemy/test/orm/inheritance/concrete.py
393        # this should be added along other
394        kwargs = self.mapper_options
395        if self.order_by:
396            kwargs['order_by'] = self.translate_order_by(self.order_by)
397       
398        if self.version_id_col:
399            kwargs['version_id_col'] = self.get_column(self.version_id_col)
400
401        if self.inheritance in ('single', 'concrete', 'multi'):
402            if self.parent and \
403               not (self.inheritance == 'concrete' and not self.polymorphic):
404                kwargs['inherits'] = self.parent.mapper
405
406            if self.inheritance == 'multi' and self.parent:
407                col_pairs = zip(self.primary_keys,
408                                self.parent._descriptor.primary_keys)
409                kwargs['inherit_condition'] = \
410                    and_(*[pc == c for c, pc in col_pairs])
411
412            if self.polymorphic:
413                if self.children:
414                    if self.inheritance == 'concrete':
415                        keys = [(self.identity, self.entity.table)]
416                        keys.extend([(child._descriptor.identity, child.table) 
417                                     for child in self._get_children()])
418                        #XXX: we might need to change the alias name so that
419                        # children (which are parent themselves) don't end up
420                        # with the same alias than their parent?
421                        pjoin = polymorphic_union(
422                                    dict(keys), self.polymorphic, 'pjoin')
423
424                        kwargs['with_polymorphic'] = ('*', pjoin)
425                        kwargs['polymorphic_on'] = \
426                            getattr(pjoin.c, self.polymorphic)
427                    elif not self.parent:
428                        kwargs['polymorphic_on'] = \
429                            self.get_column(self.polymorphic)
430
431                    #TODO: this is an optimization, and it breaks the multi
432                    # table polymorphic inheritance test with a relation.
433                    # So I turn it off for now. We might want to provide an
434                    # option to turn it on.
435#                    if self.inheritance == 'multi':
436#                        children = self._get_children()
437#                        join = self.entity.table
438#                        for child in children:
439#                            join = join.outerjoin(child.table)
440#                        kwargs['select_table'] = join
441
442                if self.children or self.parent:
443                    kwargs['polymorphic_identity'] = self.identity
444
445                if self.parent and self.inheritance == 'concrete':
446                    kwargs['concrete'] = True
447
448        if 'primary_key' in kwargs:
449            cols = self.entity.table.c
450            kwargs['primary_key'] = [getattr(cols, colname) for
451                colname in kwargs['primary_key']]
452
453        if self.parent and self.inheritance == 'single':
454            args = []
455        else:
456            args = [self.entity.table]
457
458        self.entity.mapper = _do_mapping(self.session, self.entity,
459                                         properties=self.properties,
460                                         *args, **kwargs)
461
462    def after_mapper(self):
463        self.call_builders('after_mapper')
464
465    def setup_properties(self):
466        self.call_builders('create_properties')
467
468    def finalize(self):
469        self.call_builders('finalize')
470        self.entity._setup_done = True
471
472    #----------------
473    # helper methods
474
475    def call_builders(self, what):
476        for builder in self.builders:
477            if hasattr(builder, what):
478                getattr(builder, what)()
479
480    def add_column(self, col, check_duplicate=None):
481        '''when check_duplicate is None, the value of the allowcoloverride
482        option of the entity is used.
483        '''
484        if check_duplicate is None:
485            check_duplicate = not self.allowcoloverride
486       
487        if check_duplicate and self.get_column(col.key, False) is not None:
488            raise Exception("Column '%s' already exist in '%s' ! " % 
489                            (col.key, self.entity.__name__))
490        self._columns.append(col)
491       
492        if col.primary_key:
493            self.has_pk = True
494
495        # Autosetup triggers shouldn't be active anymore at this point, so we
496        # can theoretically access the entity's table safely. But the problem
497        # is that if, for some reason, the trigger removal phase didn't
498        # happen, we'll get an infinite loop. So we just make sure we don't
499        # get one in any case.
500        table = type.__getattribute__(self.entity, 'table')
501        if table:
502            if check_duplicate and col.key in table.columns.keys():
503                raise Exception("Column '%s' already exist in table '%s' ! " % 
504                                (col.key, table.name))
505            table.append_column(col)
506
507    def add_constraint(self, constraint):
508        self.constraints.append(constraint)
509       
510        table = self.entity.table
511        if table:
512            table.append_constraint(constraint)
513
514    def add_property(self, name, property, check_duplicate=True):
515        if check_duplicate and name in self.properties:
516            raise Exception("property '%s' already exist in '%s' ! " % 
517                            (name, self.entity.__name__))
518        self.properties[name] = property
519
520#FIXME: something like this is needed to propagate the relationships from
521# parent entities to their children in a concrete inheritance scenario. But
522# this doesn't work because of the backref matching code.
523#        if self.children and self.inheritance == 'concrete':
524#            for child in self.children:
525#                child._descriptor.add_property(name, property)
526
527        mapper = self.entity.mapper
528        if mapper:
529            mapper.add_property(name, property)
530       
531    def add_mapper_extension(self, extension):
532        extensions = self.mapper_options.get('extension', [])
533        if not isinstance(extensions, list):
534            extensions = [extensions]
535        extensions.append(extension)
536        self.mapper_options['extension'] = extensions
537
538    def get_column(self, key, check_missing=True):
539        "need to support both the case where the table is already setup or not"
540        #TODO: support SA table/autoloaded entity
541        for col in self.columns:
542            if col.key == key:
543                return col
544        if check_missing:
545            raise Exception("No column named '%s' found in the table of the "
546                            "'%s' entity!" % (key, self.entity.__name__))
547        return None
548
549    def get_inverse_relation(self, rel, reverse=False):
550        '''
551        Return the inverse relation of rel, if any, None otherwise.
552        '''
553
554        matching_rel = None
555        for other_rel in self.relationships:
556            if other_rel.is_inverse(rel):
557                if matching_rel is None:
558                    matching_rel = other_rel
559                else:
560                    raise Exception(
561                            "Several relations match as inverse of the '%s' "
562                            "relation in entity '%s'. You should specify "
563                            "inverse relations manually by using the inverse "
564                            "keyword."
565                            % (rel.name, rel.entity.__name__))
566        # When a matching inverse is found, we check that it has only
567        # one relation matching as its own inverse. We don't need the result
568        # of the method though. But we do need to be careful not to start an
569        # infinite recursive loop.
570        if matching_rel and not reverse:
571            rel.entity._descriptor.get_inverse_relation(matching_rel, True)
572
573        return matching_rel
574
575    def find_relationship(self, name):
576        for rel in self.relationships:
577            if rel.name == name:
578                return rel
579        if self.parent:
580            return self.parent._descriptor.find_relationship(name)
581        else:
582            return None
583
584    def columns(self):
585        #FIXME: this would be more correct but it breaks inheritance, so I'll
586        # use the old test for now.
587#        if self.entity.table:
588        if self.autoload: 
589            return self.entity.table.columns
590        else:
591            #FIXME: depending on the type of inheritance, we should also
592            # return the parent entity's columns (for example for order_by
593            # using a column defined in the parent.
594            return self._columns
595    columns = property(columns)
596
597    def primary_keys(self):
598        """
599        Returns the list of primary key columns of the entity.
600
601        This property isn't valid before the "create_pk_cols" phase.
602        """
603        if self.autoload:
604            return [col for col in self.entity.table.primary_key.columns]
605        else:
606            if self.parent and self.inheritance == 'single':
607                return self.parent._descriptor.primary_keys
608            else:
609                return [col for col in self.columns if col.primary_key]
610    primary_keys = property(primary_keys)
611
612
613class TriggerProxy(object):
614    """
615    A class that serves as a "trigger" ; accessing its attributes runs
616    the setup_all function.
617
618    Note that the `setup_all` is called on each access of the attribute.
619    """
620
621    def __init__(self, class_, attrname):
622        self.class_ = class_
623        self.attrname = attrname
624
625    def __getattr__(self, name):
626        elixir.setup_all()
627        #FIXME: it's possible to get an infinite loop here if setup_all doesn't
628        #remove the triggers for this entity. This can happen if the entity is
629        #not in the `entities` list for some reason.
630        proxied_attr = getattr(self.class_, self.attrname)
631        return getattr(proxied_attr, name)
632
633    def __repr__(self):
634        proxied_attr = getattr(self.class_, self.attrname)
635        return "<TriggerProxy (%s)>" % (self.class_.__name__)
636
637
638class TriggerAttribute(object):
639
640    def __init__(self, attrname):
641        self.attrname = attrname
642
643    def __get__(self, instance, owner):
644        #FIXME: it's possible to get an infinite loop here if setup_all doesn't
645        #remove the triggers for this entity. This can happen if the entity is
646        #not in the `entities` list for some reason.
647        elixir.setup_all()
648        return getattr(owner, self.attrname)
649
650def is_base(cls):
651    """
652    Scan bases classes to see if any is an instance of EntityMeta. If we
653    don't find any, it means the current entity is a base class (like
654    the 'Entity' class).
655    """
656    for base in cls.__bases__:
657        if isinstance(base, EntityMeta):
658            return False
659    return True
660
661class EntityMeta(type):
662    """
663    Entity meta class.
664    You should only use it directly if you want to define your own base class
665    for your entities (ie you don't want to use the provided 'Entity' class).
666    """
667    _entities = {}
668
669    def __init__(cls, name, bases, dict_):
670        # Only process further subclasses of the base classes (Entity et al.),
671        # not the base classes themselves. We don't want the base entities to
672        # be registered in an entity collection, nor to have a table name and
673        # so on.
674        if is_base(cls):
675            return
676
677        # build a dict of entities for each frame where there are entities
678        # defined
679        caller_frame = sys._getframe(1)
680        cid = cls._caller = id(caller_frame)
681        caller_entities = EntityMeta._entities.setdefault(cid, {})
682        caller_entities[name] = cls
683
684        # Append all entities which are currently visible by the entity. This
685        # will find more entities only if some of them where imported from
686        # another module.
687        for entity in [e for e in caller_frame.f_locals.values() 
688                         if isinstance(e, EntityMeta)]:
689            caller_entities[entity.__name__] = entity
690
691        # create the entity descriptor
692        desc = cls._descriptor = EntityDescriptor(cls)
693
694        # Process attributes (using the assignment syntax), looking for
695        # 'Property' instances and attaching them to this entity.
696        properties = [(name, attr) for name, attr in dict_.iteritems() 
697                                   if isinstance(attr, Property)]
698        sorted_props = sorted(properties, key=lambda i: i[1]._counter)
699
700        for name, prop in sorted_props:
701            prop.attach(cls, name)
702
703        # Process mutators. Needed before _install_autosetup_triggers so that
704        # we know of the metadata
705        process_mutators(cls)
706
707        # setup misc options here (like tablename etc.)
708        desc.setup_options()
709
710        # create trigger proxies
711        # TODO: support entity_name... It makes sense only for autoloaded
712        # tables for now, and would make more sense if we support "external"
713        # tables
714        if desc.autosetup:
715            _install_autosetup_triggers(cls)
716
717    def __call__(cls, *args, **kwargs):
718        if cls._descriptor.autosetup and not hasattr(cls, '_setup_done'):
719            elixir.setup_all()
720        return type.__call__(cls, *args, **kwargs)
721
722
723def _install_autosetup_triggers(cls, entity_name=None):
724    #TODO: move as much as possible of those "_private" values to the
725    # descriptor, so that we don't mess the initial class.
726    tablename = cls._descriptor.tablename
727    schema = cls._descriptor.table_options.get('schema', None)
728    cls._table_key = sqlalchemy.schema._get_table_key(tablename, schema)
729
730    table_proxy = TriggerProxy(cls, 'table')
731
732    md = cls._descriptor.metadata
733    md.tables[cls._table_key] = table_proxy
734
735    # We need to monkeypatch the metadata's table iterator method because
736    # otherwise it doesn't work if the setup is triggered by the
737    # metadata.create_all().
738    # This is because ManyToMany relationships add tables AFTER the list
739    # of tables that are going to be created is "computed"
740    # (metadata.tables.values()).
741    # see:
742    # - table_iterator method in MetaData class in sqlalchemy/schema.py
743    # - visit_metadata method in sqlalchemy/ansisql.py
744    original_table_iterator = md.table_iterator
745    if not hasattr(original_table_iterator, 
746                   '_non_elixir_patched_iterator'):
747        def table_iterator(*args, **kwargs):
748            elixir.setup_all()
749            return original_table_iterator(*args, **kwargs)
750        table_iterator.__doc__ = original_table_iterator.__doc__
751        table_iterator._non_elixir_patched_iterator = \
752            original_table_iterator
753        md.table_iterator = table_iterator
754
755    #TODO: we might want to add all columns that will be available as
756    #attributes on the class itself (in SA 0.4). This would be a pretty
757    #rare usecase, as people will hit the query attribute before the
758    #column attributes, but still...
759    for name in ('c', 'table', 'mapper', 'query'):
760        setattr(cls, name, TriggerAttribute(name))
761
762    cls._has_triggers = True
763
764
765def _cleanup_autosetup_triggers(cls):
766    if not hasattr(cls, '_has_triggers'):
767        return
768
769    for name in ('table', 'mapper'):
770        setattr(cls, name, None)
771
772    for name in ('c', 'query'):
773        delattr(cls, name)
774
775    desc = cls._descriptor
776    md = desc.metadata
777
778    # the fake table could have already been removed (namely in a
779    # single table inheritance scenario)
780    md.tables.pop(cls._table_key, None)
781
782    # restore original table iterator if not done already
783    if hasattr(md.table_iterator, '_non_elixir_patched_iterator'):
784        md.table_iterator = \
785            md.table_iterator._non_elixir_patched_iterator
786
787    del cls._has_triggers
788
789   
790def setup_entities(entities):
791    '''Setup all entities in the list passed as argument'''
792
793    for entity in entities:
794        if entity._descriptor.autosetup:
795            _cleanup_autosetup_triggers(entity)
796
797    for method_name in (
798            'setup_autoload_table', 'create_pk_cols', 'setup_relkeys',
799            'before_table', 'setup_table', 'setup_reltables', 'after_table',
800            'setup_events',
801            'before_mapper', 'setup_mapper', 'after_mapper',
802            'setup_properties',
803            'finalize'):
804        for entity in entities:
805            if hasattr(entity, '_setup_done'):
806                continue
807            method = getattr(entity._descriptor, method_name)
808            method()
809
810
811def cleanup_entities(entities):
812    """
813    Try to revert back the list of entities passed as argument to the state
814    they had just before their setup phase. It will not work entirely for
815    autosetup entities as we need to remove the autosetup triggers.
816
817    As of now, this function is *not* functional in that it doesn't revert to
818    the exact same state the entities were before setup. For example, the
819    properties do not work yet as those would need to be regenerated (since the
820    columns they are based on are regenerated too -- and as such the
821    corresponding joins are not correct) but this doesn't happen because of
822    the way relationship setup is designed to be called only once (especially
823    the backref stuff in create_properties).
824    """
825    for entity in entities:
826        desc = entity._descriptor
827        if desc.autosetup:
828            _cleanup_autosetup_triggers(entity)
829
830        if hasattr(entity, '_setup_done'):
831            del entity._setup_done
832
833        entity.table = None
834        entity.mapper = None
835       
836        desc._pk_col_done = False
837        desc.has_pk = False
838        desc._columns = []
839        desc.constraints = []
840        desc.properties = {}
841
842
843class Entity(object):
844    '''
845    The base class for all entities
846   
847    All Elixir model objects should inherit from this class. Statements can
848    appear within the body of the definition of an entity to define its
849    fields, relationships, and other options.
850   
851    Here is an example:
852
853    .. sourcecode:: python
854   
855        class Person(Entity):
856            name = Field(Unicode(128))
857            birthdate = Field(DateTime, default=datetime.now)
858   
859    Please note, that if you don't specify any primary keys, Elixir will
860    automatically create one called ``id``.
861   
862    For further information, please refer to the provided examples or
863    tutorial.
864    '''
865    __metaclass__ = EntityMeta
866   
867    def __init__(self, **kwargs):
868        for key, value in kwargs.items():
869            setattr(self, key, value)
870
871    def set(self, **kwargs):
872        self.from_dict(kwargs)
873
874    def from_dict(self, data):
875        """
876        Update a mapped class with data from a JSON-style nested dict/list
877        structure.
878        """
879        mapper = sqlalchemy.orm.object_mapper(self)
880        session = sqlalchemy.orm.object_session(self)
881        pkey = [c for c in mapper.mapped_table.columns if c.primary_key]
882
883        for col in mapper.mapped_table.c:
884            if not col.primary_key and data.has_key(col.name):
885                setattr(self, col.name, data[col.name])
886
887        for rel in mapper.iterate_properties:
888            rname = rel.key
889            if isinstance(rel, sqlalchemy.orm.properties.PropertyLoader) \
890                    and data.has_key(rname):
891                dbdata = getattr(self, rname)
892                if rel.uselist:
893                    # Build a lookup dict: {(pk1, pk2): value}
894                    lookup = dict([
895                        (tuple([getattr(o, c.name) for c in pkey]), o) 
896                        for o in dbdata])
897                    for row in data[rname]:
898                        # If any primary key columns are missing or None,
899                        # create a new object
900                        if [1 for c in pkey if not row.get(c.name)]:
901                            subobj = rel.mapper.class_()
902                            dbdata.append(subobj)
903                        else:
904                            key = tuple([row[c.name] for c in pkey])
905                            subobj = lookup.pop(key, None)
906
907                            # If the row isn't found, we must fail the request
908                            # in a web scenario, this could be a parameter
909                            # tampering attack
910                            if not subobj:
911                                raise sqlalchemy.exceptions.ArgumentError(
912                                        '%s row not found in database: %s' \
913                                        % (rname, repr(row)))
914                        subobj.from_dict(row)
915
916                    # Make sure the object list attribute doesn't contain any
917                    # old value (which are not present in the new data).
918                    for delobj in lookup.itervalues():
919                        dbdata.remove(delobj)
920                        session.delete(delobj)
921                else:
922                    if data[rname] is None:
923                        setattr(self, rname, None)
924                    else:
925                        if not dbdata:
926                            dbdata = rel.mapper.class_()
927                            setattr(self, rname, dbdata)
928                        dbdata.from_dict(data[rname])
929
930    def to_dict(self, deep={}, exclude=[]):
931        """Generate a JSON-style nested dict/list structure from an object."""
932        data = dict([(col.name, getattr(self, col.name))
933                     for col in self.table.c if col.name not in exclude])
934        for rname, rdeep in deep.iteritems():
935            dbdata = getattr(self, rname)
936            fks = self.mapper.get_property(rname).foreign_keys
937            exclude = [c.name for c in fks]
938            if isinstance(dbdata, list):
939                data[rname] = [o.to_dict(rdeep, exclude) for o in dbdata]
940            else:
941                data[rname] = dbdata.to_dict(rdeep, exclude)
942        return data
943
944    # session methods
945    def flush(self, *args, **kwargs):
946        return object_session(self).flush([self], *args, **kwargs)
947
948    def delete(self, *args, **kwargs):
949        return object_session(self).delete(self, *args, **kwargs)
950
951    def expire(self, *args, **kwargs):
952        return object_session(self).expire(self, *args, **kwargs)
953
954    def refresh(self, *args, **kwargs):
955        return object_session(self).refresh(self, *args, **kwargs)
956
957    def expunge(self, *args, **kwargs):
958        return object_session(self).expunge(self, *args, **kwargs)
959
960    # This bunch of session methods, along with all the query methods below
961    # only make sense when using a global/scoped/contextual session.
962    def _global_session(self):
963        return self._descriptor.session.registry()
964    _global_session = property(_global_session)
965
966    def merge(self, *args, **kwargs):
967        return self._global_session.merge(self, *args, **kwargs)
968
969    def save(self, *args, **kwargs):
970        return self._global_session.save(self, *args, **kwargs)
971
972    def update(self, *args, **kwargs):
973        return self._global_session.update(self, *args, **kwargs)
974
975    def save_or_update(self, *args, **kwargs):
976        return self._global_session.save_or_update(self, *args, **kwargs)
977
978    # query methods
979    def get_by(cls, *args, **kwargs):
980        return cls.query.filter_by(*args, **kwargs).first()
981    get_by = classmethod(get_by)
982
983    def get(cls, *args, **kwargs):
984        return cls.query.get(*args, **kwargs)
985    get = classmethod(get)
986
987    #-----------------#
988    # DEPRECATED LAND #
989    #-----------------#
990
991    def filter(cls, *args, **kwargs):
992        warnings.warn("The filter method on the class is deprecated."
993                      "You should use cls.query.filter(...)", 
994                      DeprecationWarning, stacklevel=2)
995        return cls.query.filter(*args, **kwargs)
996    filter = classmethod(filter)
997
998    def filter_by(cls, *args, **kwargs):
999        warnings.warn("The filter_by method on the class is deprecated."
1000                      "You should use cls.query.filter_by(...)", 
1001                      DeprecationWarning, stacklevel=2)
1002        return cls.query.filter_by(*args, **kwargs)
1003    filter_by = classmethod(filter_by)
1004
1005    def select(cls, *args, **kwargs):
1006        warnings.warn("The select method on the class is deprecated."
1007                      "You should use cls.query.filter(...).all()", 
1008                      DeprecationWarning, stacklevel=2)
1009        return cls.query.filter(*args, **kwargs).all()
1010    select = classmethod(select)
1011
1012    def select_by(cls, *args, **kwargs):
1013        warnings.warn("The select_by method on the class is deprecated."
1014                      "You should use cls.query.filter_by(...).all()", 
1015                      DeprecationWarning, stacklevel=2)
1016        return cls.query.filter_by(*args, **kwargs).all()
1017    select_by = classmethod(select_by)
1018
1019    def selectfirst(cls, *args, **kwargs):
1020        warnings.warn("The selectfirst method on the class is deprecated."
1021                      "You should use cls.query.filter(...).first()", 
1022                      DeprecationWarning, stacklevel=2)
1023        return cls.query.filter(*args, **kwargs).first()
1024    selectfirst = classmethod(selectfirst)
1025
1026    def selectfirst_by(cls, *args, **kwargs):
1027        warnings.warn("The selectfirst_by method on the class is deprecated."
1028                      "You should use cls.query.filter_by(...).first()", 
1029                      DeprecationWarning, stacklevel=2)
1030        return cls.query.filter_by(*args, **kwargs).first()
1031    selectfirst_by = classmethod(selectfirst_by)
1032
1033    def selectone(cls, *args, **kwargs):
1034        warnings.warn("The selectone method on the class is deprecated."
1035                      "You should use cls.query.filter(...).one()", 
1036                      DeprecationWarning, stacklevel=2)
1037        return cls.query.filter(*args, **kwargs).one()
1038    selectone = classmethod(selectone)
1039
1040    def selectone_by(cls, *args, **kwargs):
1041        warnings.warn("The selectone_by method on the class is deprecated."
1042                      "You should use cls.query.filter_by(...).one()", 
1043                      DeprecationWarning, stacklevel=2)
1044        return cls.query.filter_by(*args, **kwargs).one()
1045    selectone_by = classmethod(selectone_by)
1046
1047    def join_to(cls, *args, **kwargs):
1048        warnings.warn("The join_to method on the class is deprecated."
1049                      "You should use cls.query.join(...)", 
1050                      DeprecationWarning, stacklevel=2)
1051        return cls.query.join_to(*args, **kwargs).all()
1052    join_to = classmethod(join_to)
1053
1054    def join_via(cls, *args, **kwargs):
1055        warnings.warn("The join_via method on the class is deprecated."
1056                      "You should use cls.query.join(...)", 
1057                      DeprecationWarning, stacklevel=2)
1058        return cls.query.join_via(*args, **kwargs).all()
1059    join_via = classmethod(join_via)
1060
1061    def count(cls, *args, **kwargs):
1062        warnings.warn("The count method on the class is deprecated."
1063                      "You should use cls.query.filter(...).count()", 
1064                      DeprecationWarning, stacklevel=2)
1065        return cls.query.filter(*args, **kwargs).count()
1066    count = classmethod(count)
1067
1068    def count_by(cls, *args, **kwargs):
1069        warnings.warn("The count_by method on the class is deprecated."
1070                      "You should use cls.query.filter_by(...).count()", 
1071                      DeprecationWarning, stacklevel=2)
1072        return cls.query.filter_by(*args, **kwargs).count()
1073    count_by = classmethod(count_by)
1074
1075    def options(cls, *args, **kwargs):
1076        warnings.warn("The options method on the class is deprecated."
1077                      "You should use cls.query.options(...)", 
1078                      DeprecationWarning, stacklevel=2)
1079        return cls.query.options(*args, **kwargs)
1080    options = classmethod(options)
1081
1082    def instances(cls, *args, **kwargs):
1083        warnings.warn("The instances method on the class is deprecated."
1084                      "You should use cls.query.instances(...)", 
1085                      DeprecationWarning, stacklevel=2)
1086        return cls.query.instances(*args, **kwargs)
1087    instances = classmethod(instances)
1088
1089
Note: See TracBrowser for help on using the browser.