root / elixir / trunk / elixir / entity.py @ 335

Revision 335, 42.6 kB (checked in by cleverdevil, 5 years ago)

The to_dict method on Entity was not looking for columns in the case of multi-table inheritance. Now, the columns to be included in the to_dict are copied from the mapper.tables list.

Line 
1'''
2This module provides the ``Entity`` base class, as well as its metaclass
3``EntityMeta``.
4'''
5
6from py23compat import set, rsplit, sorted
7
8import sys
9import warnings
10
11import sqlalchemy
12from sqlalchemy                    import Table, Column, Integer, \
13                                          desc, ForeignKey, and_, \
14                                          ForeignKeyConstraint
15from sqlalchemy.orm                import Query, MapperExtension, \
16                                          mapper, object_session, EXT_PASS, \
17                                          polymorphic_union
18from sqlalchemy.ext.sessioncontext import SessionContext
19
20import elixir
21from elixir.statements import process_mutators
22from elixir import options
23from elixir.properties import Property
24
25
26__doc_all__ = ['Entity', 'EntityMeta']
27
28
29try: 
30    from sqlalchemy.orm import ScopedSession
31except ImportError: 
32    # Not on sqlalchemy version 0.4
33    ScopedSession = type(None)
34
35
36def _do_mapping(session, cls, *args, **kwargs):
37    if session is None:
38        return mapper(cls, *args, **kwargs)
39    elif isinstance(session, ScopedSession):
40        return session.mapper(cls, *args, **kwargs)
41    elif isinstance(session, SessionContext):
42        extension = kwargs.pop('extension', None)
43        if extension is not None:
44            if not isinstance(extension, list):
45                extension = [extension]
46            extension.append(session.mapper_extension)
47        else:
48            extension = session.mapper_extension
49
50        class query(object):
51            def __getattr__(s, key):
52                return getattr(session.registry().query(cls), key)
53
54            def __call__(s):
55                return session.registry().query(cls)
56
57        if not 'query' in cls.__dict__: 
58            cls.query = query()
59
60        return mapper(cls, extension=extension, *args, **kwargs)
61    else:
62        raise Exception("Failed to map entity '%s' with its table or "
63                        "selectable" % cls.__name__)
64
65
66class EntityDescriptor(object):
67    '''
68    EntityDescriptor describes fields and options needed for table creation.
69    '''
70   
71    def __init__(self, entity):
72        entity.table = None
73        entity.mapper = None
74
75        self.entity = entity
76        self.module = sys.modules[entity.__module__]
77
78        self.has_pk = False
79        self._pk_col_done = False
80
81        self.builders = []
82
83        self.is_base = is_base(entity)
84        self.parent = None
85        self.children = []
86
87        for base in entity.__bases__:
88            if isinstance(base, EntityMeta) and not is_base(base):
89                if self.parent:
90                    raise Exception('%s entity inherits from several entities,'
91                                    ' and this is not supported.' 
92                                    % self.entity.__name__)
93                else:
94                    self.parent = base
95                    self.parent._descriptor.children.append(entity)
96
97        # columns and constraints waiting for a table to exist
98        self._columns = list()
99        self.constraints = list()
100        # properties waiting for a mapper to exist
101        self.properties = dict()
102
103        #
104        self.relationships = list()
105
106        # set default value for options
107        self.order_by = None
108        self.table_args = list()
109
110        # set default value for options with an optional module-level default
111        self.metadata = getattr(self.module, '__metadata__', elixir.metadata)
112        self.session = getattr(self.module, '__session__', elixir.session)
113        self.objectstore = None
114        self.collection = getattr(self.module, '__entity_collection__',
115                                  elixir.entities)
116
117        for option in ('autosetup', 'inheritance', 'polymorphic', 'identity',
118                       'autoload', 'tablename', 'shortnames', 
119                       'auto_primarykey', 'version_id_col', 
120                       'allowcoloverride'):
121            setattr(self, option, options.options_defaults[option])
122
123        for option_dict in ('mapper_options', 'table_options'):
124            setattr(self, option_dict, 
125                    options.options_defaults[option_dict].copy())
126
127    def setup_options(self):
128        '''
129        Setup any values that might depend on using_options. For example, the
130        tablename or the metadata.
131        '''
132        elixir.metadatas.add(self.metadata)
133        if self.collection is not None:
134            self.collection.append(self.entity)
135            self.collection.map_entity(self.entity, self.entity.__name__)
136
137        objectstore = None
138        session = self.session
139        if session is None or isinstance(session, ScopedSession):
140            # no stinking objectstore
141            pass
142        elif isinstance(session, SessionContext):
143            objectstore = elixir.Objectstore(session)
144        elif not hasattr(session, 'registry'):
145            # Both SessionContext and ScopedSession have a registry attribute,
146            # but objectstores (whether Elixir's or Activemapper's) don't, so
147            # if we are here, it means an Objectstore is used for the session.
148#XXX: still true for activemapper post 0.4?           
149            objectstore = session
150            session = objectstore.context
151
152        self.session = session
153        self.objectstore = objectstore
154
155        entity = self.entity
156        if self.parent:
157            if self.inheritance == 'single':
158                self.tablename = self.parent._descriptor.tablename
159
160        if not self.tablename:
161            if self.shortnames:
162                self.tablename = entity.__name__.lower()
163            else:
164                modulename = entity.__module__.replace('.', '_')
165                tablename = "%s_%s" % (modulename, entity.__name__)
166                self.tablename = tablename.lower()
167        elif callable(self.tablename):
168            self.tablename = self.tablename(entity)
169
170        if not self.identity:
171            if 'polymorphic_identity' in self.mapper_options:
172                self.identity = self.mapper_options['polymorphic_identity']
173            else:
174                #TODO: include module name
175                self.identity = entity.__name__.lower()
176        elif 'polymorphic_identity' in kwargs:
177            raise Exception('You cannot use the "identity" option and the '
178                            'polymorphic_identity mapper option at the same '
179                            'time.')
180        elif callable(self.identity):
181            self.identity = self.identity(entity)
182
183        if self.polymorphic:
184            if not isinstance(self.polymorphic, basestring):
185                self.polymorphic = options.DEFAULT_POLYMORPHIC_COL_NAME
186
187    #---------------------
188    # setup phase methods
189
190    def setup_autoload_table(self):
191        self.setup_table(True)
192
193    def create_pk_cols(self):
194        """
195        Create primary_key columns. That is, call the 'create_pk_cols'
196        builders then add a primary key to the table if it hasn't already got
197        one and needs one.
198       
199        This method is "semi-recursive" in some cases: it calls the
200        create_keys method on ManyToOne relationships and those in turn call
201        create_pk_cols on their target. It shouldn't be possible to have an
202        infinite loop since a loop of primary_keys is not a valid situation.
203        """
204        if self._pk_col_done:
205            return
206
207        self.call_builders('create_pk_cols')
208
209        if not self.autoload:
210            if self.parent:
211                if self.inheritance == 'multi':
212                    # Add columns with foreign keys to the parent's primary
213                    # key columns
214                    parent_desc = self.parent._descriptor
215                    schema = parent_desc.table_options.get('schema', None)
216                    tablename = parent_desc.tablename 
217                    if schema is not None:
218                        tablename = "%s.%s" % (schema, tablename)
219                    for pk_col in parent_desc.primary_keys:
220                        colname = options.MULTIINHERITANCECOL_NAMEFORMAT % \
221                                  {'entity': self.parent.__name__.lower(),
222                                   'key': pk_col.key}
223
224                        # It seems like SA ForeignKey is not happy being given
225                        # a real column object when said column is not yet
226                        # attached to a table
227                        pk_col_name = "%s.%s" % (tablename, pk_col.key)
228                        fk = ForeignKey(pk_col_name, ondelete='cascade')
229                        col = Column(colname, pk_col.type, fk,
230                                     primary_key=True)
231                        self.add_column(col)
232                elif self.inheritance == 'concrete':
233                    # Copy primary key columns from the parent.
234                    for col in self.parent._descriptor.columns:
235                        if col.primary_key:
236                            self.add_column(col.copy())
237            elif not self.has_pk and self.auto_primarykey:
238                if isinstance(self.auto_primarykey, basestring):
239                    colname = self.auto_primarykey
240                else:
241                    colname = options.DEFAULT_AUTO_PRIMARYKEY_NAME
242
243                self.add_column(
244                    Column(colname, options.DEFAULT_AUTO_PRIMARYKEY_TYPE, 
245                           primary_key=True))
246        self._pk_col_done = True
247
248    def setup_relkeys(self):
249        self.call_builders('create_non_pk_cols')
250
251    def before_table(self):
252        self.call_builders('before_table')
253       
254    def setup_table(self, only_autoloaded=False):
255        '''
256        Create a SQLAlchemy table-object with all columns that have been
257        defined up to this point.
258        '''
259        if self.entity.table:
260            return
261
262        if self.autoload != only_autoloaded:
263            return
264
265        kwargs = self.table_options
266        if self.autoload:
267            args = self.table_args
268            kwargs['autoload'] = True
269        else:
270            if self.parent:
271                if self.inheritance == 'single':
272                    # we know the parent is setup before the child
273                    self.entity.table = self.parent.table 
274
275                    # re-add the entity columns to the parent entity so that they
276                    # are added to the parent's table (whether the parent's table
277                    # is already setup or not).
278                    for col in self.columns:
279                        self.parent._descriptor.add_column(col)
280                    for constraint in self.constraints:
281                        self.parent._descriptor.add_constraint(constraint)
282                    return
283                elif self.inheritance == 'concrete': 
284                    #TODO: we should also copy columns from the parent table
285                    # if the parent is a base (abstract?) entity (whatever the
286                    # inheritance type -> elif will need to be changed)
287
288                    # Copy all non-primary key columns from parent table
289                    # (primary key columns have already been copied earlier).
290                    for col in self.parent._descriptor.columns:
291                        if not col.primary_key:
292                            self.add_column(col.copy())
293
294                    #FIXME: use the public equivalent of _get_colspec when
295                    #available
296                    for con in self.parent._descriptor.constraints:
297                        self.add_constraint(
298                            ForeignKeyConstraint(
299                                [c.key for c in con.columns],
300                                [e._get_colspec() for e in con.elements],
301                                name=con.name, #TODO: modify it
302                                onupdate=con.onupdate, ondelete=con.ondelete,
303                                use_alter=con.use_alter))
304
305            if self.polymorphic and \
306               self.inheritance in ('single', 'multi') and \
307               self.children and not self.parent:
308                self.add_column(Column(self.polymorphic, 
309                                       options.POLYMORPHIC_COL_TYPE))
310
311            if self.version_id_col:
312                if not isinstance(self.version_id_col, basestring):
313                    self.version_id_col = options.DEFAULT_VERSION_ID_COL_NAME
314                self.add_column(Column(self.version_id_col, Integer))
315
316            args = self.columns + self.constraints + self.table_args
317       
318        self.entity.table = Table(self.tablename, self.metadata, 
319                                  *args, **kwargs)
320
321    def setup_reltables(self):
322        self.call_builders('create_tables')
323
324    def after_table(self):
325        self.call_builders('after_table')
326
327    def setup_events(self):
328        def make_proxy_method(methods):
329            def proxy_method(self, mapper, connection, instance):
330                for func in methods:
331                    ret = func(instance)
332                    # I couldn't commit myself to force people to
333                    # systematicaly return EXT_PASS in all their event methods.
334                    # But not doing that diverge to how SQLAlchemy works.
335                    # I should try to convince Mike to do EXT_PASS by default,
336                    # and stop processing as the special case.
337#                    if ret != EXT_PASS:
338                    if ret is not None and ret != EXT_PASS:
339                        return ret
340                return EXT_PASS
341            return proxy_method
342
343        # create a list of callbacks for each event
344        methods = {}
345        for name, method in self.entity.__dict__.items():
346            if hasattr(method, '_elixir_events'):
347                for event in method._elixir_events:
348                    event_methods = methods.setdefault(event, [])
349                    event_methods.append(method)
350        if not methods:
351            return
352       
353        # transform that list into methods themselves
354        for event in methods:
355            methods[event] = make_proxy_method(methods[event])
356       
357        # create a custom mapper extension class, tailored to our entity
358        ext = type('EventMapperExtension', (MapperExtension,), methods)()
359       
360        # then, make sure that the entity's mapper has our mapper extension
361        self.add_mapper_extension(ext)
362
363    def before_mapper(self):
364        self.call_builders('before_mapper')
365
366    def _get_children(self):
367        children = self.children[:]
368        for child in self.children:
369            children.extend(child._descriptor._get_children())
370        return children
371
372    def translate_order_by(self, order_by):
373        if isinstance(order_by, basestring):
374            order_by = [order_by]
375       
376        order = list()
377        for colname in order_by:
378            col = self.get_column(colname.strip('-'))
379            if colname.startswith('-'):
380                col = desc(col)
381            order.append(col)
382        return order
383
384    def setup_mapper(self):
385        '''
386        Initializes and assign an (empty!) mapper to the entity.
387        '''
388        if self.entity.mapper:
389            return
390
391        # for now we don't support the "abstract" parent class in a concrete
392        # inheritance scenario as demonstrated in
393        # sqlalchemy/test/orm/inheritance/concrete.py
394        # this should be added along other
395        kwargs = self.mapper_options
396        if self.order_by:
397            kwargs['order_by'] = self.translate_order_by(self.order_by)
398       
399        if self.version_id_col:
400            kwargs['version_id_col'] = self.get_column(self.version_id_col)
401
402        if self.inheritance in ('single', 'concrete', 'multi'):
403            if self.parent and \
404               not (self.inheritance == 'concrete' and not self.polymorphic):
405                kwargs['inherits'] = self.parent.mapper
406
407            if self.inheritance == 'multi' and self.parent:
408                col_pairs = zip(self.primary_keys,
409                                self.parent._descriptor.primary_keys)
410                kwargs['inherit_condition'] = \
411                    and_(*[pc == c for c, pc in col_pairs])
412
413            if self.polymorphic:
414                if self.children:
415                    if self.inheritance == 'concrete':
416                        keys = [(self.identity, self.entity.table)]
417                        keys.extend([(child._descriptor.identity, child.table) 
418                                     for child in self._get_children()])
419                        #XXX: we might need to change the alias name so that
420                        # children (which are parent themselves) don't end up
421                        # with the same alias than their parent?
422                        pjoin = polymorphic_union(
423                                    dict(keys), self.polymorphic, 'pjoin')
424
425                        kwargs['with_polymorphic'] = ('*', pjoin)
426                        kwargs['polymorphic_on'] = \
427                            getattr(pjoin.c, self.polymorphic)
428                    elif not self.parent:
429                        kwargs['polymorphic_on'] = \
430                            self.get_column(self.polymorphic)
431
432                    #TODO: this is an optimization, and it breaks the multi
433                    # table polymorphic inheritance test with a relation.
434                    # So I turn it off for now. We might want to provide an
435                    # option to turn it on.
436#                    if self.inheritance == 'multi':
437#                        children = self._get_children()
438#                        join = self.entity.table
439#                        for child in children:
440#                            join = join.outerjoin(child.table)
441#                        kwargs['select_table'] = join
442
443                if self.children or self.parent:
444                    kwargs['polymorphic_identity'] = self.identity
445
446                if self.parent and self.inheritance == 'concrete':
447                    kwargs['concrete'] = True
448
449        if 'primary_key' in kwargs:
450            cols = self.entity.table.c
451            kwargs['primary_key'] = [getattr(cols, colname) for
452                colname in kwargs['primary_key']]
453
454        if self.parent and self.inheritance == 'single':
455            args = []
456        else:
457            args = [self.entity.table]
458
459        self.entity.mapper = _do_mapping(self.session, self.entity,
460                                         properties=self.properties,
461                                         *args, **kwargs)
462
463    def after_mapper(self):
464        self.call_builders('after_mapper')
465
466    def setup_properties(self):
467        self.call_builders('create_properties')
468
469    def finalize(self):
470        self.call_builders('finalize')
471        self.entity._setup_done = True
472
473    #----------------
474    # helper methods
475
476    def call_builders(self, what):
477        for builder in self.builders:
478            if hasattr(builder, what):
479                getattr(builder, what)()
480
481    def add_column(self, col, check_duplicate=None):
482        '''when check_duplicate is None, the value of the allowcoloverride
483        option of the entity is used.
484        '''
485        if check_duplicate is None:
486            check_duplicate = not self.allowcoloverride
487       
488        if check_duplicate and self.get_column(col.key, False) is not None:
489            raise Exception("Column '%s' already exist in '%s' ! " % 
490                            (col.key, self.entity.__name__))
491        self._columns.append(col)
492       
493        if col.primary_key:
494            self.has_pk = True
495
496        # Autosetup triggers shouldn't be active anymore at this point, so we
497        # can theoretically access the entity's table safely. But the problem
498        # is that if, for some reason, the trigger removal phase didn't
499        # happen, we'll get an infinite loop. So we just make sure we don't
500        # get one in any case.
501        table = type.__getattribute__(self.entity, 'table')
502        if table:
503            if check_duplicate and col.key in table.columns.keys():
504                raise Exception("Column '%s' already exist in table '%s' ! " % 
505                                (col.key, table.name))
506            table.append_column(col)
507
508    def add_constraint(self, constraint):
509        self.constraints.append(constraint)
510       
511        table = self.entity.table
512        if table:
513            table.append_constraint(constraint)
514
515    def add_property(self, name, property, check_duplicate=True):
516        if check_duplicate and name in self.properties:
517            raise Exception("property '%s' already exist in '%s' ! " % 
518                            (name, self.entity.__name__))
519        self.properties[name] = property
520
521#FIXME: something like this is needed to propagate the relationships from
522# parent entities to their children in a concrete inheritance scenario. But
523# this doesn't work because of the backref matching code.
524#        if self.children and self.inheritance == 'concrete':
525#            for child in self.children:
526#                child._descriptor.add_property(name, property)
527
528        mapper = self.entity.mapper
529        if mapper:
530            mapper.add_property(name, property)
531       
532    def add_mapper_extension(self, extension):
533        extensions = self.mapper_options.get('extension', [])
534        if not isinstance(extensions, list):
535            extensions = [extensions]
536        extensions.append(extension)
537        self.mapper_options['extension'] = extensions
538
539    def get_column(self, key, check_missing=True):
540        "need to support both the case where the table is already setup or not"
541        #TODO: support SA table/autoloaded entity
542        for col in self.columns:
543            if col.key == key:
544                return col
545        if check_missing:
546            raise Exception("No column named '%s' found in the table of the "
547                            "'%s' entity!" % (key, self.entity.__name__))
548        return None
549
550    def get_inverse_relation(self, rel, reverse=False):
551        '''
552        Return the inverse relation of rel, if any, None otherwise.
553        '''
554
555        matching_rel = None
556        for other_rel in self.relationships:
557            if other_rel.is_inverse(rel):
558                if matching_rel is None:
559                    matching_rel = other_rel
560                else:
561                    raise Exception(
562                            "Several relations match as inverse of the '%s' "
563                            "relation in entity '%s'. You should specify "
564                            "inverse relations manually by using the inverse "
565                            "keyword."
566                            % (rel.name, rel.entity.__name__))
567        # When a matching inverse is found, we check that it has only
568        # one relation matching as its own inverse. We don't need the result
569        # of the method though. But we do need to be careful not to start an
570        # infinite recursive loop.
571        if matching_rel and not reverse:
572            rel.entity._descriptor.get_inverse_relation(matching_rel, True)
573
574        return matching_rel
575
576    def find_relationship(self, name):
577        for rel in self.relationships:
578            if rel.name == name:
579                return rel
580        if self.parent:
581            return self.parent._descriptor.find_relationship(name)
582        else:
583            return None
584
585    def columns(self):
586        #FIXME: this would be more correct but it breaks inheritance, so I'll
587        # use the old test for now.
588#        if self.entity.table:
589        if self.autoload: 
590            return self.entity.table.columns
591        else:
592            #FIXME: depending on the type of inheritance, we should also
593            # return the parent entity's columns (for example for order_by
594            # using a column defined in the parent.
595            return self._columns
596    columns = property(columns)
597
598    def primary_keys(self):
599        """
600        Returns the list of primary key columns of the entity.
601
602        This property isn't valid before the "create_pk_cols" phase.
603        """
604        if self.autoload:
605            return [col for col in self.entity.table.primary_key.columns]
606        else:
607            if self.parent and self.inheritance == 'single':
608                return self.parent._descriptor.primary_keys
609            else:
610                return [col for col in self.columns if col.primary_key]
611    primary_keys = property(primary_keys)
612
613
614class TriggerProxy(object):
615    """
616    A class that serves as a "trigger" ; accessing its attributes runs
617    the setup_all function.
618
619    Note that the `setup_all` is called on each access of the attribute.
620    """
621
622    def __init__(self, class_, attrname):
623        self.class_ = class_
624        self.attrname = attrname
625
626    def __getattr__(self, name):
627        elixir.setup_all()
628        #FIXME: it's possible to get an infinite loop here if setup_all doesn't
629        #remove the triggers for this entity. This can happen if the entity is
630        #not in the `entities` list for some reason.
631        proxied_attr = getattr(self.class_, self.attrname)
632        return getattr(proxied_attr, name)
633
634    def __repr__(self):
635        proxied_attr = getattr(self.class_, self.attrname)
636        return "<TriggerProxy (%s)>" % (self.class_.__name__)
637
638
639class TriggerAttribute(object):
640
641    def __init__(self, attrname):
642        self.attrname = attrname
643
644    def __get__(self, instance, owner):
645        #FIXME: it's possible to get an infinite loop here if setup_all doesn't
646        #remove the triggers for this entity. This can happen if the entity is
647        #not in the `entities` list for some reason.
648        elixir.setup_all()
649        return getattr(owner, self.attrname)
650
651def is_base(cls):
652    """
653    Scan bases classes to see if any is an instance of EntityMeta. If we
654    don't find any, it means the current entity is a base class (like
655    the 'Entity' class).
656    """
657    for base in cls.__bases__:
658        if isinstance(base, EntityMeta):
659            return False
660    return True
661
662class EntityMeta(type):
663    """
664    Entity meta class.
665    You should only use it directly if you want to define your own base class
666    for your entities (ie you don't want to use the provided 'Entity' class).
667    """
668    _entities = {}
669
670    def __init__(cls, name, bases, dict_):
671        # Only process further subclasses of the base classes (Entity et al.),
672        # not the base classes themselves. We don't want the base entities to
673        # be registered in an entity collection, nor to have a table name and
674        # so on.
675        if is_base(cls):
676            return
677
678        # build a dict of entities for each frame where there are entities
679        # defined
680        caller_frame = sys._getframe(1)
681        cid = cls._caller = id(caller_frame)
682        caller_entities = EntityMeta._entities.setdefault(cid, {})
683        caller_entities[name] = cls
684
685        # Append all entities which are currently visible by the entity. This
686        # will find more entities only if some of them where imported from
687        # another module.
688        for entity in [e for e in caller_frame.f_locals.values() 
689                         if isinstance(e, EntityMeta)]:
690            caller_entities[entity.__name__] = entity
691
692        # create the entity descriptor
693        desc = cls._descriptor = EntityDescriptor(cls)
694
695        # Process attributes (using the assignment syntax), looking for
696        # 'Property' instances and attaching them to this entity.
697        properties = [(name, attr) for name, attr in dict_.iteritems() 
698                                   if isinstance(attr, Property)]
699        sorted_props = sorted(properties, key=lambda i: i[1]._counter)
700
701        for name, prop in sorted_props:
702            prop.attach(cls, name)
703
704        # Process mutators. Needed before _install_autosetup_triggers so that
705        # we know of the metadata
706        process_mutators(cls)
707
708        # setup misc options here (like tablename etc.)
709        desc.setup_options()
710
711        # create trigger proxies
712        # TODO: support entity_name... It makes sense only for autoloaded
713        # tables for now, and would make more sense if we support "external"
714        # tables
715        if desc.autosetup:
716            _install_autosetup_triggers(cls)
717
718    def __call__(cls, *args, **kwargs):
719        if cls._descriptor.autosetup and not hasattr(cls, '_setup_done'):
720            elixir.setup_all()
721        return type.__call__(cls, *args, **kwargs)
722
723
724def _install_autosetup_triggers(cls, entity_name=None):
725    #TODO: move as much as possible of those "_private" values to the
726    # descriptor, so that we don't mess the initial class.
727    tablename = cls._descriptor.tablename
728    schema = cls._descriptor.table_options.get('schema', None)
729    cls._table_key = sqlalchemy.schema._get_table_key(tablename, schema)
730
731    table_proxy = TriggerProxy(cls, 'table')
732
733    md = cls._descriptor.metadata
734    md.tables[cls._table_key] = table_proxy
735
736    # We need to monkeypatch the metadata's table iterator method because
737    # otherwise it doesn't work if the setup is triggered by the
738    # metadata.create_all().
739    # This is because ManyToMany relationships add tables AFTER the list
740    # of tables that are going to be created is "computed"
741    # (metadata.tables.values()).
742    # see:
743    # - table_iterator method in MetaData class in sqlalchemy/schema.py
744    # - visit_metadata method in sqlalchemy/ansisql.py
745    original_table_iterator = md.table_iterator
746    if not hasattr(original_table_iterator, 
747                   '_non_elixir_patched_iterator'):
748        def table_iterator(*args, **kwargs):
749            elixir.setup_all()
750            return original_table_iterator(*args, **kwargs)
751        table_iterator.__doc__ = original_table_iterator.__doc__
752        table_iterator._non_elixir_patched_iterator = \
753            original_table_iterator
754        md.table_iterator = table_iterator
755
756    #TODO: we might want to add all columns that will be available as
757    #attributes on the class itself (in SA 0.4). This would be a pretty
758    #rare usecase, as people will hit the query attribute before the
759    #column attributes, but still...
760    for name in ('c', 'table', 'mapper', 'query'):
761        setattr(cls, name, TriggerAttribute(name))
762
763    cls._has_triggers = True
764
765
766def _cleanup_autosetup_triggers(cls):
767    if not hasattr(cls, '_has_triggers'):
768        return
769
770    for name in ('table', 'mapper'):
771        setattr(cls, name, None)
772
773    for name in ('c', 'query'):
774        delattr(cls, name)
775
776    desc = cls._descriptor
777    md = desc.metadata
778
779    # the fake table could have already been removed (namely in a
780    # single table inheritance scenario)
781    md.tables.pop(cls._table_key, None)
782
783    # restore original table iterator if not done already
784    if hasattr(md.table_iterator, '_non_elixir_patched_iterator'):
785        md.table_iterator = \
786            md.table_iterator._non_elixir_patched_iterator
787
788    del cls._has_triggers
789
790   
791def setup_entities(entities):
792    '''Setup all entities in the list passed as argument'''
793
794    for entity in entities:
795        if entity._descriptor.autosetup:
796            _cleanup_autosetup_triggers(entity)
797
798    for method_name in (
799            'setup_autoload_table', 'create_pk_cols', 'setup_relkeys',
800            'before_table', 'setup_table', 'setup_reltables', 'after_table',
801            'setup_events',
802            'before_mapper', 'setup_mapper', 'after_mapper',
803            'setup_properties',
804            'finalize'):
805        for entity in entities:
806            if hasattr(entity, '_setup_done'):
807                continue
808            method = getattr(entity._descriptor, method_name)
809            method()
810
811
812def cleanup_entities(entities):
813    """
814    Try to revert back the list of entities passed as argument to the state
815    they had just before their setup phase. It will not work entirely for
816    autosetup entities as we need to remove the autosetup triggers.
817
818    As of now, this function is *not* functional in that it doesn't revert to
819    the exact same state the entities were before setup. For example, the
820    properties do not work yet as those would need to be regenerated (since the
821    columns they are based on are regenerated too -- and as such the
822    corresponding joins are not correct) but this doesn't happen because of
823    the way relationship setup is designed to be called only once (especially
824    the backref stuff in create_properties).
825    """
826    for entity in entities:
827        desc = entity._descriptor
828        if desc.autosetup:
829            _cleanup_autosetup_triggers(entity)
830
831        if hasattr(entity, '_setup_done'):
832            del entity._setup_done
833
834        entity.table = None
835        entity.mapper = None
836       
837        desc._pk_col_done = False
838        desc.has_pk = False
839        desc._columns = []
840        desc.constraints = []
841        desc.properties = {}
842
843
844class Entity(object):
845    '''
846    The base class for all entities
847   
848    All Elixir model objects should inherit from this class. Statements can
849    appear within the body of the definition of an entity to define its
850    fields, relationships, and other options.
851   
852    Here is an example:
853
854    .. sourcecode:: python
855   
856        class Person(Entity):
857            name = Field(Unicode(128))
858            birthdate = Field(DateTime, default=datetime.now)
859   
860    Please note, that if you don't specify any primary keys, Elixir will
861    automatically create one called ``id``.
862   
863    For further information, please refer to the provided examples or
864    tutorial.
865    '''
866    __metaclass__ = EntityMeta
867   
868    def __init__(self, **kwargs):
869        for key, value in kwargs.items():
870            setattr(self, key, value)
871
872    def set(self, **kwargs):
873        self.from_dict(kwargs)
874
875    def from_dict(self, data):
876        """
877        Update a mapped class with data from a JSON-style nested dict/list
878        structure.
879        """
880        mapper = sqlalchemy.orm.object_mapper(self)
881        session = sqlalchemy.orm.object_session(self)
882
883        for col in mapper.mapped_table.c:
884            if not col.primary_key and data.has_key(col.name):
885                setattr(self, col.name, data[col.name])
886
887        for rel in mapper.iterate_properties:
888            rname = rel.key
889            if isinstance(rel, sqlalchemy.orm.properties.PropertyLoader) \
890                    and data.has_key(rname):
891                dbdata = getattr(self, rname)
892                if rel.uselist:
893                    pkey = [c for c in rel.table.columns if c.primary_key]
894
895                    # Build a lookup dict: {(pk1, pk2): value}
896                    lookup = dict([
897                        (tuple([getattr(o, c.name) for c in pkey]), o) 
898                        for o in dbdata])
899                    for row in data[rname]:
900                        # If any primary key columns are missing or None,
901                        # create a new object
902                        if [1 for c in pkey if not row.get(c.name)]:
903                            subobj = rel.mapper.class_()
904                            dbdata.append(subobj)
905                        else:
906                            key = tuple([row[c.name] for c in pkey])
907                            subobj = lookup.pop(key, None)
908
909                            # If the row isn't found, we must fail the request
910                            # in a web scenario, this could be a parameter
911                            # tampering attack
912                            if not subobj:
913                                raise sqlalchemy.exceptions.ArgumentError(
914                                        '%s row not found in database: %s' \
915                                        % (rname, repr(row)))
916                        subobj.from_dict(row)
917
918                    # Make sure the object list attribute doesn't contain any
919                    # old value (which are not present in the new data).
920                    for delobj in lookup.itervalues():
921                        dbdata.remove(delobj)
922                        session.delete(delobj)
923                else:
924                    if data[rname] is None:
925                        setattr(self, rname, None)
926                    else:
927                        if not dbdata:
928                            dbdata = rel.mapper.class_()
929                            setattr(self, rname, dbdata)
930                        dbdata.from_dict(data[rname])
931
932    def to_dict(self, deep={}, exclude=[]):
933        """Generate a JSON-style nested dict/list structure from an object."""
934        columns = []
935        for table in self.mapper.tables:
936            for col in table.c:
937                columns.append(col)
938           
939        data = dict([(col.name, getattr(self, col.name))
940                     for col in columns if col.name not in exclude])
941        for rname, rdeep in deep.iteritems():
942            dbdata = getattr(self, rname)
943            fks = self.mapper.get_property(rname).foreign_keys
944            exclude = [c.name for c in fks]
945            if isinstance(dbdata, list):
946                data[rname] = [o.to_dict(rdeep, exclude) for o in dbdata]
947            else:
948                data[rname] = dbdata.to_dict(rdeep, exclude)
949        return data
950
951    # session methods
952    def flush(self, *args, **kwargs):
953        return object_session(self).flush([self], *args, **kwargs)
954
955    def delete(self, *args, **kwargs):
956        return object_session(self).delete(self, *args, **kwargs)
957
958    def expire(self, *args, **kwargs):
959        return object_session(self).expire(self, *args, **kwargs)
960
961    def refresh(self, *args, **kwargs):
962        return object_session(self).refresh(self, *args, **kwargs)
963
964    def expunge(self, *args, **kwargs):
965        return object_session(self).expunge(self, *args, **kwargs)
966
967    # This bunch of session methods, along with all the query methods below
968    # only make sense when using a global/scoped/contextual session.
969    def _global_session(self):
970        return self._descriptor.session.registry()
971    _global_session = property(_global_session)
972
973    def merge(self, *args, **kwargs):
974        return self._global_session.merge(self, *args, **kwargs)
975
976    def save(self, *args, **kwargs):
977        return self._global_session.save(self, *args, **kwargs)
978
979    def update(self, *args, **kwargs):
980        return self._global_session.update(self, *args, **kwargs)
981
982    def save_or_update(self, *args, **kwargs):
983        return self._global_session.save_or_update(self, *args, **kwargs)
984
985    # query methods
986    def get_by(cls, *args, **kwargs):
987        return cls.query.filter_by(*args, **kwargs).first()
988    get_by = classmethod(get_by)
989
990    def get(cls, *args, **kwargs):
991        return cls.query.get(*args, **kwargs)
992    get = classmethod(get)
993
994    #-----------------#
995    # DEPRECATED LAND #
996    #-----------------#
997
998    def filter(cls, *args, **kwargs):
999        warnings.warn("The filter method on the class is deprecated."
1000                      "You should use cls.query.filter(...)", 
1001                      DeprecationWarning, stacklevel=2)
1002        return cls.query.filter(*args, **kwargs)
1003    filter = classmethod(filter)
1004
1005    def filter_by(cls, *args, **kwargs):
1006        warnings.warn("The filter_by method on the class is deprecated."
1007                      "You should use cls.query.filter_by(...)", 
1008                      DeprecationWarning, stacklevel=2)
1009        return cls.query.filter_by(*args, **kwargs)
1010    filter_by = classmethod(filter_by)
1011
1012    def select(cls, *args, **kwargs):
1013        warnings.warn("The select method on the class is deprecated."
1014                      "You should use cls.query.filter(...).all()", 
1015                      DeprecationWarning, stacklevel=2)
1016        return cls.query.filter(*args, **kwargs).all()
1017    select = classmethod(select)
1018
1019    def select_by(cls, *args, **kwargs):
1020        warnings.warn("The select_by method on the class is deprecated."
1021                      "You should use cls.query.filter_by(...).all()", 
1022                      DeprecationWarning, stacklevel=2)
1023        return cls.query.filter_by(*args, **kwargs).all()
1024    select_by = classmethod(select_by)
1025
1026    def selectfirst(cls, *args, **kwargs):
1027        warnings.warn("The selectfirst method on the class is deprecated."
1028                      "You should use cls.query.filter(...).first()", 
1029                      DeprecationWarning, stacklevel=2)
1030        return cls.query.filter(*args, **kwargs).first()
1031    selectfirst = classmethod(selectfirst)
1032
1033    def selectfirst_by(cls, *args, **kwargs):
1034        warnings.warn("The selectfirst_by method on the class is deprecated."
1035                      "You should use cls.query.filter_by(...).first()", 
1036                      DeprecationWarning, stacklevel=2)
1037        return cls.query.filter_by(*args, **kwargs).first()
1038    selectfirst_by = classmethod(selectfirst_by)
1039
1040    def selectone(cls, *args, **kwargs):
1041        warnings.warn("The selectone method on the class is deprecated."
1042                      "You should use cls.query.filter(...).one()", 
1043                      DeprecationWarning, stacklevel=2)
1044        return cls.query.filter(*args, **kwargs).one()
1045    selectone = classmethod(selectone)
1046
1047    def selectone_by(cls, *args, **kwargs):
1048        warnings.warn("The selectone_by method on the class is deprecated."
1049                      "You should use cls.query.filter_by(...).one()", 
1050                      DeprecationWarning, stacklevel=2)
1051        return cls.query.filter_by(*args, **kwargs).one()
1052    selectone_by = classmethod(selectone_by)
1053
1054    def join_to(cls, *args, **kwargs):
1055        warnings.warn("The join_to method on the class is deprecated."
1056                      "You should use cls.query.join(...)", 
1057                      DeprecationWarning, stacklevel=2)
1058        return cls.query.join_to(*args, **kwargs).all()
1059    join_to = classmethod(join_to)
1060
1061    def join_via(cls, *args, **kwargs):
1062        warnings.warn("The join_via method on the class is deprecated."
1063                      "You should use cls.query.join(...)", 
1064                      DeprecationWarning, stacklevel=2)
1065        return cls.query.join_via(*args, **kwargs).all()
1066    join_via = classmethod(join_via)
1067
1068    def count(cls, *args, **kwargs):
1069        warnings.warn("The count method on the class is deprecated."
1070                      "You should use cls.query.filter(...).count()", 
1071                      DeprecationWarning, stacklevel=2)
1072        return cls.query.filter(*args, **kwargs).count()
1073    count = classmethod(count)
1074
1075    def count_by(cls, *args, **kwargs):
1076        warnings.warn("The count_by method on the class is deprecated."
1077                      "You should use cls.query.filter_by(...).count()", 
1078                      DeprecationWarning, stacklevel=2)
1079        return cls.query.filter_by(*args, **kwargs).count()
1080    count_by = classmethod(count_by)
1081
1082    def options(cls, *args, **kwargs):
1083        warnings.warn("The options method on the class is deprecated."
1084                      "You should use cls.query.options(...)", 
1085                      DeprecationWarning, stacklevel=2)
1086        return cls.query.options(*args, **kwargs)
1087    options = classmethod(options)
1088
1089    def instances(cls, *args, **kwargs):
1090        warnings.warn("The instances method on the class is deprecated."
1091                      "You should use cls.query.instances(...)", 
1092                      DeprecationWarning, stacklevel=2)
1093        return cls.query.instances(*args, **kwargs)
1094    instances = classmethod(instances)
1095
1096
Note: See TracBrowser for help on using the browser.