root / elixir / trunk / elixir / entity.py

Revision 534, 38.5 kB (checked in by ged, 12 months ago)

- Fixed bad foreign key constraint generated for classes inheriting from a

class with multiple primary keys when using the "multi" inheritance.
Patch from & closes #114.

Line 
1'''
2This module provides the ``Entity`` base class, as well as its metaclass
3``EntityMeta``.
4'''
5
6import sys
7import types
8import warnings
9
10from copy import deepcopy
11
12import sqlalchemy
13from sqlalchemy import Table, Column, Integer, desc, ForeignKey, and_, \
14                       ForeignKeyConstraint
15from sqlalchemy.orm import MapperExtension, mapper, object_session, \
16                           EXT_CONTINUE, polymorphic_union, ScopedSession, \
17                           ColumnProperty
18from sqlalchemy.sql import ColumnCollection
19
20import elixir
21from elixir.statements import process_mutators, MUTATORS
22from elixir import options
23from elixir.properties import Property
24
25DEBUG = False
26
27__doc_all__ = ['Entity', 'EntityMeta']
28
29
30def session_mapper_factory(scoped_session):
31    def session_mapper(cls, *args, **kwargs):
32        if kwargs.pop('save_on_init', True):
33            old_init = cls.__init__
34            def __init__(self, *args, **kwargs):
35                old_init(self, *args, **kwargs)
36                scoped_session.add(self)
37            cls.__init__ = __init__
38        cls.query = scoped_session.query_property()
39        return mapper(cls, *args, **kwargs)
40    return session_mapper
41
42
43class EntityDescriptor(object):
44    '''
45    EntityDescriptor describes fields and options needed for table creation.
46    '''
47
48    def __init__(self, entity):
49        self.entity = entity
50        self.parent = None
51
52        bases = []
53        for base in entity.__bases__:
54            if isinstance(base, EntityMeta):
55                if is_entity(base) and not is_abstract_entity(base):
56                    if self.parent:
57                        raise Exception(
58                            '%s entity inherits from several entities, '
59                            'and this is not supported.'
60                            % self.entity.__name__)
61                    else:
62                        self.parent = base
63                        bases.extend(base._descriptor.bases)
64                        self.parent._descriptor.children.append(entity)
65                else:
66                    bases.append(base)
67        self.bases = bases
68        if not is_entity(entity) or is_abstract_entity(entity):
69            return
70
71        # entity.__module__ is not always reliable (eg in mod_python)
72        self.module = sys.modules.get(entity.__module__)
73
74        self.builders = []
75
76        #XXX: use entity.__subclasses__ ?
77        self.children = []
78
79        # used for multi-table inheritance
80        self.join_condition = None
81        self.has_pk = False
82        self._pk_col_done = False
83
84        # columns and constraints waiting for a table to exist
85        self._columns = ColumnCollection()
86        self.constraints = []
87
88        # properties (it is only useful for checking dupe properties at the
89        # moment, and when adding properties before the mapper is created,
90        # which shouldn't happen).
91        self.properties = {}
92
93        #
94        self.relationships = []
95
96        # set default value for options
97        self.table_args = []
98
99        # base class(es) options_defaults
100        options_defaults = self.options_defaults()
101
102        complete_defaults = options.options_defaults.copy()
103        complete_defaults.update({
104            'metadata': elixir.metadata,
105            'session': elixir.session,
106            'collection': elixir.entities
107        })
108
109        # set default value for other options
110        for key in options.valid_options:
111            value = options_defaults.get(key, complete_defaults[key])
112            if isinstance(value, dict):
113                value = value.copy()
114            setattr(self, key, value)
115
116        # override options with module-level defaults defined
117        for key in ('metadata', 'session', 'collection'):
118            attr = '__%s__' % key
119            if hasattr(self.module, attr):
120                setattr(self, key, getattr(self.module, attr))
121
122    def options_defaults(self):
123        base_defaults = {}
124        for base in self.bases:
125            base_defaults.update(base._descriptor.options_defaults())
126        base_defaults.update(getattr(self.entity, 'options_defaults', {}))
127        return base_defaults
128
129    def setup_options(self):
130        '''
131        Setup any values that might depend on the "using_options" class
132        mutator. For example, the tablename or the metadata.
133        '''
134        elixir.metadatas.add(self.metadata)
135        if self.collection is not None:
136            self.collection.append(self.entity)
137
138        entity = self.entity
139        if self.parent:
140            if self.inheritance == 'single':
141                self.tablename = self.parent._descriptor.tablename
142
143        if not self.tablename:
144            if self.shortnames:
145                self.tablename = entity.__name__.lower()
146            else:
147                modulename = entity.__module__.replace('.', '_')
148                tablename = "%s_%s" % (modulename, entity.__name__)
149                self.tablename = tablename.lower()
150        elif hasattr(self.tablename, '__call__'):
151            self.tablename = self.tablename(entity)
152
153        if not self.identity:
154            if 'polymorphic_identity' in self.mapper_options:
155                self.identity = self.mapper_options['polymorphic_identity']
156            else:
157                #TODO: include module name (We could have b.Account inherit
158                # from a.Account)
159                self.identity = entity.__name__.lower()
160        elif 'polymorphic_identity' in self.mapper_options:
161            raise Exception('You cannot use the "identity" option and the '
162                            'polymorphic_identity mapper option at the same '
163                            'time.')
164        elif hasattr(self.identity, '__call__'):
165            self.identity = self.identity(entity)
166
167        if self.polymorphic:
168            if not isinstance(self.polymorphic, basestring):
169                self.polymorphic = options.DEFAULT_POLYMORPHIC_COL_NAME
170
171    #---------------------
172    # setup phase methods
173
174    def setup_autoload_table(self):
175        self.setup_table(True)
176
177    def create_pk_cols(self):
178        """
179        Create primary_key columns. That is, call the 'create_pk_cols'
180        builders then add a primary key to the table if it hasn't already got
181        one and needs one.
182
183        This method is "semi-recursive" in some cases: it calls the
184        create_keys method on ManyToOne relationships and those in turn call
185        create_pk_cols on their target. It shouldn't be possible to have an
186        infinite loop since a loop of primary_keys is not a valid situation.
187        """
188        if self._pk_col_done:
189            return
190
191        self.call_builders('create_pk_cols')
192
193        if not self.autoload:
194            if self.parent:
195                if self.inheritance == 'multi':
196                    # Add columns with foreign keys to the parent's primary
197                    # key columns
198                    parent_desc = self.parent._descriptor
199                    tablename = parent_desc.table_fullname
200                    join_clauses = []
201                    fk_columns = []
202                    for pk_col in parent_desc.primary_keys:
203                        colname = options.MULTIINHERITANCECOL_NAMEFORMAT % \
204                                  {'entity': self.parent.__name__.lower(),
205                                   'key': pk_col.key}
206
207                        # It seems like SA ForeignKey is not happy being given
208                        # a real column object when said column is not yet
209                        # attached to a table
210                        pk_col_name = "%s.%s" % (tablename, pk_col.key)
211                        col = Column(colname, pk_col.type, primary_key=True)
212                        fk_columns.append(col)
213                        self.add_column(col)
214                        join_clauses.append(col == pk_col)
215                    self.join_condition = and_(*join_clauses)
216                    self.add_constraint(
217                        ForeignKeyConstraint(fk_columns,
218                            parent_desc.primary_keys, ondelete='CASCADE'))
219                elif self.inheritance == 'concrete':
220                    # Copy primary key columns from the parent.
221                    for col in self.parent._descriptor.columns:
222                        if col.primary_key:
223                            self.add_column(col.copy())
224            elif not self.has_pk and self.auto_primarykey:
225                if isinstance(self.auto_primarykey, basestring):
226                    colname = self.auto_primarykey
227                else:
228                    colname = options.DEFAULT_AUTO_PRIMARYKEY_NAME
229
230                self.add_column(
231                    Column(colname, options.DEFAULT_AUTO_PRIMARYKEY_TYPE,
232                           primary_key=True))
233        self._pk_col_done = True
234
235    def setup_relkeys(self):
236        self.call_builders('create_non_pk_cols')
237
238    def before_table(self):
239        self.call_builders('before_table')
240
241    def setup_table(self, only_autoloaded=False):
242        '''
243        Create a SQLAlchemy table-object with all columns that have been
244        defined up to this point.
245        '''
246        if self.entity.table is not None:
247            return
248
249        if self.autoload != only_autoloaded:
250            return
251
252        kwargs = self.table_options
253        if self.autoload:
254            args = self.table_args
255            kwargs['autoload'] = True
256        else:
257            if self.parent:
258                if self.inheritance == 'single':
259                    # we know the parent is setup before the child
260                    self.entity.table = self.parent.table
261
262                    # re-add the entity columns to the parent entity so that
263                    # they are added to the parent's table (whether the
264                    # parent's table is already setup or not).
265                    for col in self._columns:
266                        self.parent._descriptor.add_column(col)
267                    for constraint in self.constraints:
268                        self.parent._descriptor.add_constraint(constraint)
269                    return
270                elif self.inheritance == 'concrete':
271                    #TODO: we should also copy columns from the parent table
272                    # if the parent is a base (abstract?) entity (whatever the
273                    # inheritance type -> elif will need to be changed)
274
275                    # Copy all non-primary key columns from parent table
276                    # (primary key columns have already been copied earlier).
277                    for col in self.parent._descriptor.columns:
278                        if not col.primary_key:
279                            self.add_column(col.copy())
280
281                    for con in self.parent._descriptor.constraints:
282                        self.add_constraint(
283                            ForeignKeyConstraint(
284                                [e.parent.key for e in con.elements],
285                                [e.target_fullname for e in con.elements],
286                                name=con.name, #TODO: modify it
287                                onupdate=con.onupdate, ondelete=con.ondelete,
288                                use_alter=con.use_alter))
289
290            if self.polymorphic and \
291               self.inheritance in ('single', 'multi') and \
292               self.children and not self.parent:
293                self.add_column(Column(self.polymorphic,
294                                       options.POLYMORPHIC_COL_TYPE))
295
296            if self.version_id_col:
297                if not isinstance(self.version_id_col, basestring):
298                    self.version_id_col = options.DEFAULT_VERSION_ID_COL_NAME
299                self.add_column(Column(self.version_id_col, Integer))
300
301            args = list(self.columns) + self.constraints + self.table_args
302        self.entity.table = Table(self.tablename, self.metadata,
303                                  *args, **kwargs)
304        if DEBUG:
305            print self.entity.table.repr2()
306
307    def setup_reltables(self):
308        self.call_builders('create_tables')
309
310    def after_table(self):
311        self.call_builders('after_table')
312
313    def setup_events(self):
314        def make_proxy_method(methods):
315            def proxy_method(self, mapper, connection, instance):
316                for func in methods:
317                    ret = func(instance)
318                    # I couldn't commit myself to force people to
319                    # systematicaly return EXT_CONTINUE in all their event
320                    # methods.
321                    # But not doing that diverge to how SQLAlchemy works.
322                    # I should try to convince Mike to do EXT_CONTINUE by
323                    # default, and stop processing as the special case.
324#                    if ret != EXT_CONTINUE:
325                    if ret is not None and ret != EXT_CONTINUE:
326                        return ret
327                return EXT_CONTINUE
328            return proxy_method
329
330        # create a list of callbacks for each event
331        methods = {}
332
333        all_methods = getmembers(self.entity,
334                                 lambda a: isinstance(a, types.MethodType))
335
336        for name, method in all_methods:
337            for event in getattr(method, '_elixir_events', []):
338                event_methods = methods.setdefault(event, [])
339                event_methods.append(method)
340
341        if not methods:
342            return
343
344        # transform that list into methods themselves
345        for event in methods:
346            methods[event] = make_proxy_method(methods[event])
347
348        # create a custom mapper extension class, tailored to our entity
349        ext = type('EventMapperExtension', (MapperExtension,), methods)()
350
351        # then, make sure that the entity's mapper has our mapper extension
352        self.add_mapper_extension(ext)
353
354    def before_mapper(self):
355        self.call_builders('before_mapper')
356
357    def _get_children(self):
358        children = self.children[:]
359        for child in self.children:
360            children.extend(child._descriptor._get_children())
361        return children
362
363    def translate_order_by(self, order_by):
364        if isinstance(order_by, basestring):
365            order_by = [order_by]
366
367        order = []
368        for colname in order_by:
369            #FIXME: get_column uses self.columns[key] instead of property
370            # names. self.columns correspond to the columns of the table if
371            # the table was already created and to self._columns otherwise,
372            # which is a ColumnCollection indexed on columns.key
373            # See ticket #108.
374            col = self.get_column(colname.strip('-'))
375            if colname.startswith('-'):
376                col = desc(col)
377            order.append(col)
378        return order
379
380    def setup_mapper(self):
381        '''
382        Initializes and assign a mapper to the entity.
383        At this point the mapper will usually have no property as they are
384        added later.
385        '''
386        if self.entity.mapper:
387            return
388
389        # for now we don't support the "abstract" parent class in a concrete
390        # inheritance scenario as demonstrated in
391        # sqlalchemy/test/orm/inheritance/concrete.py
392        # this should be added along other
393        kwargs = {}
394        if self.order_by:
395            kwargs['order_by'] = self.translate_order_by(self.order_by)
396
397        if self.version_id_col:
398            kwargs['version_id_col'] = self.get_column(self.version_id_col)
399
400        if self.inheritance in ('single', 'concrete', 'multi'):
401            if self.parent and \
402               (self.inheritance != 'concrete' or self.polymorphic):
403                # non-polymorphic concrete doesn't need this
404                kwargs['inherits'] = self.parent.mapper
405
406            if self.inheritance == 'multi' and self.parent:
407                kwargs['inherit_condition'] = self.join_condition
408
409            if self.polymorphic:
410                if self.children:
411                    if self.inheritance == 'concrete':
412                        keys = [(self.identity, self.entity.table)]
413                        keys.extend([(child._descriptor.identity, child.table)
414                                     for child in self._get_children()])
415                        # Having the same alias name for an entity and one of
416                        # its child (which is a parent itself) shouldn't cause
417                        # any problem because the join shouldn't be used at
418                        # the same time. But in reality, some versions of SA
419                        # do misbehave on this. Since it doesn't hurt to have
420                        # different names anyway, here they go.
421                        pjoin = polymorphic_union(
422                                    dict(keys), self.polymorphic,
423                                    'pjoin_%s' % self.identity)
424
425                        kwargs['with_polymorphic'] = ('*', pjoin)
426                        kwargs['polymorphic_on'] = \
427                            getattr(pjoin.c, self.polymorphic)
428                    elif not self.parent:
429                        kwargs['polymorphic_on'] = \
430                            self.get_column(self.polymorphic)
431
432                if self.children or self.parent:
433                    kwargs['polymorphic_identity'] = self.identity
434
435                if self.parent and self.inheritance == 'concrete':
436                    kwargs['concrete'] = True
437
438        if self.parent and self.inheritance == 'single':
439            args = []
440        else:
441            args = [self.entity.table]
442
443        # let user-defined kwargs override Elixir-generated ones, though that's
444        # not very usefull since most of them expect Column instances.
445        kwargs.update(self.mapper_options)
446
447        #TODO: document this!
448        if 'primary_key' in kwargs:
449            cols = self.entity.table.c
450            kwargs['primary_key'] = [getattr(cols, colname) for
451                colname in kwargs['primary_key']]
452
453        # do the mapping
454        if self.session is None:
455            self.entity.mapper = mapper(self.entity, *args, **kwargs)
456        elif isinstance(self.session, ScopedSession):
457            session_mapper = session_mapper_factory(self.session)
458            self.entity.mapper = session_mapper(self.entity, *args, **kwargs)
459        else:
460            raise Exception("Failed to map entity '%s' with its table or "
461                            "selectable. You can only bind an Entity to a "
462                            "ScopedSession object or None for manual session "
463                            "management."
464                            % self.entity.__name__)
465
466    def after_mapper(self):
467        self.call_builders('after_mapper')
468
469    def setup_properties(self):
470        self.call_builders('create_properties')
471
472    def finalize(self):
473        self.call_builders('finalize')
474        self.entity._setup_done = True
475
476    #----------------
477    # helper methods
478
479    def call_builders(self, what):
480        for builder in self.builders:
481            if hasattr(builder, what):
482                getattr(builder, what)()
483
484    def add_column(self, col, check_duplicate=None):
485        '''when check_duplicate is None, the value of the allowcoloverride
486        option of the entity is used.
487        '''
488        if check_duplicate is None:
489            check_duplicate = not self.allowcoloverride
490
491        if col.key in self._columns:
492            if check_duplicate:
493                raise Exception("Column '%s' already exist in '%s' ! " %
494                                (col.key, self.entity.__name__))
495            else:
496                del self._columns[col.key]
497        # are indexed on col.key
498        self._columns.add(col)
499
500        if col.primary_key:
501            self.has_pk = True
502
503        table = self.entity.table
504        if table is not None:
505            if check_duplicate and col.key in table.columns.keys():
506                raise Exception("Column '%s' already exist in table '%s' ! " %
507                                (col.key, table.name))
508            table.append_column(col)
509            if DEBUG:
510                print "table.append_column(%s)" % col
511
512    def add_constraint(self, constraint):
513        self.constraints.append(constraint)
514
515        table = self.entity.table
516        if table is not None:
517            table.append_constraint(constraint)
518
519    def add_property(self, name, property, check_duplicate=True):
520        if check_duplicate and name in self.properties:
521            raise Exception("property '%s' already exist in '%s' ! " %
522                            (name, self.entity.__name__))
523        self.properties[name] = property
524
525#FIXME: something like this is needed to propagate the relationships from
526# parent entities to their children in a concrete inheritance scenario. But
527# this doesn't work because of the backref matching code. In most case
528# (test_concrete.py) it doesn't even happen at all.
529#        if self.children and self.inheritance == 'concrete':
530#            for child in self.children:
531#                child._descriptor.add_property(name, property)
532
533        mapper = self.entity.mapper
534        if mapper:
535            mapper.add_property(name, property)
536            if DEBUG:
537                print "mapper.add_property('%s', %s)" % (name, repr(property))
538
539    def add_mapper_extension(self, extension):
540        extensions = self.mapper_options.get('extension', [])
541        if not isinstance(extensions, list):
542            extensions = [extensions]
543        extensions.append(extension)
544        self.mapper_options['extension'] = extensions
545
546    def get_column(self, key, check_missing=True):
547        #TODO: this needs to work whether the table is already setup or not
548        #TODO: support SA table/autoloaded entity
549        try:
550            return self.columns[key]
551        except KeyError:
552            if check_missing:
553                raise Exception("No column named '%s' found in the table of "
554                                "the '%s' entity!"
555                                % (key, self.entity.__name__))
556
557    def get_inverse_relation(self, rel, check_reverse=True):
558        '''
559        Return the inverse relation of rel, if any, None otherwise.
560        '''
561
562        matching_rel = None
563        for other_rel in self.relationships:
564            if rel.is_inverse(other_rel):
565                if matching_rel is None:
566                    matching_rel = other_rel
567                else:
568                    raise Exception(
569                            "Several relations match as inverse of the '%s' "
570                            "relation in entity '%s'. You should specify "
571                            "inverse relations manually by using the inverse "
572                            "keyword."
573                            % (rel.name, rel.entity.__name__))
574        # When a matching inverse is found, we check that it has only
575        # one relation matching as its own inverse. We don't need the result
576        # of the method though. But we do need to be careful not to start an
577        # infinite recursive loop.
578        if matching_rel and check_reverse:
579            rel.entity._descriptor.get_inverse_relation(matching_rel, False)
580
581        return matching_rel
582
583    def find_relationship(self, name):
584        for rel in self.relationships:
585            if rel.name == name:
586                return rel
587        if self.parent:
588            return self.parent._descriptor.find_relationship(name)
589        else:
590            return None
591
592    #------------------------
593    # some useful properties
594
595    @property
596    def table_fullname(self):
597        '''
598        Complete name of the table for the related entity.
599        Includes the schema name if there is one specified.
600        '''
601        schema = self.table_options.get('schema', None)
602        if schema is not None:
603            return "%s.%s" % (schema, self.tablename)
604        else:
605            return self.tablename
606
607    @property
608    def columns(self):
609        if self.entity.table is not None:
610            return self.entity.table.columns
611        else:
612            #FIXME: depending on the type of inheritance, we should also
613            # return the parent entity's columns (for example for order_by
614            # using a column defined in the parent.
615            return self._columns
616
617    @property
618    def primary_keys(self):
619        """
620        Returns the list of primary key columns of the entity.
621
622        This property isn't valid before the "create_pk_cols" phase.
623        """
624        if self.autoload:
625            return [col for col in self.entity.table.primary_key.columns]
626        else:
627            if self.parent and self.inheritance == 'single':
628                return self.parent._descriptor.primary_keys
629            else:
630                return [col for col in self.columns if col.primary_key]
631
632    @property
633    def table(self):
634        if self.entity.table is not None:
635            return self.entity.table
636        else:
637            return FakeTable(self)
638
639    @property
640    def primary_key_properties(self):
641        """
642        Returns the list of (mapper) properties corresponding to the primary
643        key columns of the table of the entity.
644
645        This property caches its value, so it shouldn't be called before the
646        entity is fully set up.
647        """
648        if not hasattr(self, '_pk_props'):
649            col_to_prop = {}
650            mapper = self.entity.mapper
651            for prop in mapper.iterate_properties:
652                if isinstance(prop, ColumnProperty):
653                    for col in prop.columns:
654                        #XXX: Why is this extra loop necessary? What is this
655                        #     "proxy_set" supposed to mean?
656                        for col in col.proxy_set:
657                            col_to_prop[col] = prop
658            pk_cols = [c for c in mapper.mapped_table.c if c.primary_key]
659            self._pk_props = [col_to_prop[c] for c in pk_cols]
660        return self._pk_props
661
662class FakePK(object):
663    def __init__(self, descriptor):
664        self.descriptor = descriptor
665
666    @property
667    def columns(self):
668        return self.descriptor.primary_keys
669
670class FakeTable(object):
671    def __init__(self, descriptor):
672        self.descriptor = descriptor
673        self.primary_key = FakePK(descriptor)
674
675    @property
676    def columns(self):
677        return self.descriptor.columns
678
679    @property
680    def fullname(self):
681        '''
682        Complete name of the table for the related entity.
683        Includes the schema name if there is one specified.
684        '''
685        schema = self.descriptor.table_options.get('schema', None)
686        if schema is not None:
687            return "%s.%s" % (schema, self.descriptor.tablename)
688        else:
689            return self.descriptor.tablename
690
691
692def is_entity(cls):
693    """
694    Scan the bases classes of `cls` to see if any is an instance of
695    EntityMeta. If we don't find any, it means it is either an unrelated class
696    or an entity base class (like the 'Entity' class).
697    """
698    for base in cls.__bases__:
699        if isinstance(base, EntityMeta):
700            return True
701    return False
702
703
704# Note that we don't use inspect.getmembers because of
705# http://bugs.python.org/issue1785
706# See also http://elixir.ematia.de/trac/changeset/262
707def getmembers(object, predicate=None):
708    base_props = []
709    for key in dir(object):
710        try:
711            value = getattr(object, key)
712        except AttributeError:
713            continue
714        if not predicate or predicate(value):
715            base_props.append((key, value))
716    return base_props
717
718def is_abstract_entity(dict_or_cls):
719    if not isinstance(dict_or_cls, dict):
720        dict_or_cls = dict_or_cls.__dict__
721    for mutator, args, kwargs in dict_or_cls.get(MUTATORS, []):
722        if 'abstract' in kwargs:
723            return kwargs['abstract']
724
725    return False
726
727def instrument_class(cls):
728    """
729    Instrument a class as an Entity. This is usually done automatically through
730    the EntityMeta metaclass.
731    """
732    # Create the entity descriptor
733    desc = cls._descriptor = EntityDescriptor(cls)
734
735    # Process mutators
736    # We *do* want mutators to be processed for base/abstract classes
737    # (so that statements like using_options_defaults work).
738    process_mutators(cls)
739
740    # We do not want to do any more processing for base/abstract classes
741    # (Entity et al.).
742    if not is_entity(cls) or is_abstract_entity(cls):
743        return
744
745    cls.table = None
746    cls.mapper = None
747
748    # Copy the properties ('Property' instances) of the entity base class(es).
749    # We use getmembers (instead of __dict__) so that we also get the
750    # properties from the parents of the base class if any.
751    base_props = []
752    for base in cls.__bases__:
753        if isinstance(base, EntityMeta) and \
754           (not is_entity(base) or is_abstract_entity(base)):
755            base_props += [(name, deepcopy(attr)) for name, attr in
756                           getmembers(base, lambda a: isinstance(a, Property))]
757
758    # Process attributes (using the assignment syntax), looking for
759    # 'Property' instances and attaching them to this entity.
760    properties = [(name, attr) for name, attr in cls.__dict__.iteritems()
761                               if isinstance(attr, Property)]
762    sorted_props = sorted(base_props + properties,
763                          key=lambda i: i[1]._counter)
764    for name, prop in sorted_props:
765        prop.attach(cls, name)
766
767    # setup misc options here (like tablename etc.)
768    desc.setup_options()
769
770
771class EntityMeta(type):
772    """
773    Entity meta class.
774    You should only use it directly if you want to define your own base class
775    for your entities (ie you don't want to use the provided 'Entity' class).
776    """
777
778    def __init__(cls, name, bases, dict_):
779        instrument_class(cls)
780
781    def __setattr__(cls, key, value):
782        if isinstance(value, Property):
783            if hasattr(cls, '_setup_done'):
784                raise Exception('Cannot set attribute on a class after '
785                                'setup_all')
786            else:
787                value.attach(cls, key)
788        else:
789            type.__setattr__(cls, key, value)
790
791
792def setup_entities(entities):
793    '''Setup all entities in the list passed as argument'''
794
795    for entity in entities:
796        # delete all Elixir properties so that it doesn't interfere with
797        # SQLAlchemy. At this point they should have be converted to
798        # builders.
799        for name, attr in entity.__dict__.items():
800            if isinstance(attr, Property):
801                delattr(entity, name)
802
803    for method_name in (
804            'setup_autoload_table', 'create_pk_cols', 'setup_relkeys',
805            'before_table', 'setup_table', 'setup_reltables', 'after_table',
806            'setup_events',
807            'before_mapper', 'setup_mapper', 'after_mapper',
808            'setup_properties',
809            'finalize'):
810#        if DEBUG:
811#            print "=" * 40
812#            print method_name
813#            print "=" * 40
814        for entity in entities:
815#            print entity.__name__, "...",
816            if hasattr(entity, '_setup_done'):
817#                print "already done"
818                continue
819            method = getattr(entity._descriptor, method_name)
820            method()
821#            print "ok"
822
823
824def cleanup_entities(entities):
825    """
826    Try to revert back the list of entities passed as argument to the state
827    they had just before their setup phase.
828
829    As of now, this function is *not* functional in that it doesn't revert to
830    the exact same state the entities were before setup. For example, the
831    properties do not work yet as those would need to be regenerated (since the
832    columns they are based on are regenerated too -- and as such the
833    corresponding joins are not correct) but this doesn't happen because of
834    the way relationship setup is designed to be called only once (especially
835    the backref stuff in create_properties).
836    """
837    for entity in entities:
838        desc = entity._descriptor
839
840        if hasattr(entity, '_setup_done'):
841            del entity._setup_done
842
843        entity.table = None
844        entity.mapper = None
845
846        desc._pk_col_done = False
847        desc.has_pk = False
848        desc._columns = ColumnCollection()
849        desc.constraints = []
850        desc.properties = {}
851
852class EntityBase(object):
853    """
854    This class holds all methods of the "Entity" base class, but does not act
855    as a base class itself (it does not use the EntityMeta metaclass), but
856    rather as a parent class for Entity. This is meant so that people who want
857    to provide their own base class but don't want to loose or copy-paste all
858    the methods of Entity can do so by inheriting from EntityBase:
859
860    .. sourcecode:: python
861
862        class MyBase(EntityBase):
863            __metaclass__ = EntityMeta
864
865            def myCustomMethod(self):
866                # do something great
867    """
868
869    def __init__(self, **kwargs):
870        self.set(**kwargs)
871
872    def set(self, **kwargs):
873        for key, value in kwargs.iteritems():
874            setattr(self, key, value)
875
876    @classmethod
877    def update_or_create(cls, data, surrogate=True):
878        pk_props = cls._descriptor.primary_key_properties
879
880        # if all pk are present and not None
881        if not [1 for p in pk_props if data.get(p.key) is None]:
882            pk_tuple = tuple([data[prop.key] for prop in pk_props])
883            record = cls.query.get(pk_tuple)
884            if record is None:
885                if surrogate:
886                    raise Exception("Cannot create surrogate with pk")
887                else:
888                    record = cls()
889        else:
890            if surrogate:
891                record = cls()
892            else:
893                raise Exception("Cannot create non surrogate without pk")
894        record.from_dict(data)
895        return record
896
897    def from_dict(self, data):
898        """
899        Update a mapped class with data from a JSON-style nested dict/list
900        structure.
901        """
902        # surrogate can be guessed from autoincrement/sequence but I guess
903        # that's not 100% reliable, so we'll need an override
904
905        mapper = sqlalchemy.orm.object_mapper(self)
906
907        for key, value in data.iteritems():
908            if isinstance(value, dict):
909                dbvalue = getattr(self, key)
910                rel_class = mapper.get_property(key).mapper.class_
911                pk_props = rel_class._descriptor.primary_key_properties
912
913                # If the data doesn't contain any pk, and the relationship
914                # already has a value, update that record.
915                if not [1 for p in pk_props if p.key in data] and \
916                   dbvalue is not None:
917                    dbvalue.from_dict(value)
918                else:
919                    record = rel_class.update_or_create(value)
920                    setattr(self, key, record)
921            elif isinstance(value, list) and \
922                 value and isinstance(value[0], dict):
923
924                rel_class = mapper.get_property(key).mapper.class_
925                new_attr_value = []
926                for row in value:
927                    if not isinstance(row, dict):
928                        raise Exception(
929                                'Cannot send mixed (dict/non dict) data '
930                                'to list relationships in from_dict data.')
931                    record = rel_class.update_or_create(row)
932                    new_attr_value.append(record)
933                setattr(self, key, new_attr_value)
934            else:
935                setattr(self, key, value)
936
937    def to_dict(self, deep={}, exclude=[]):
938        """Generate a JSON-style nested dict/list structure from an object."""
939        col_prop_names = [p.key for p in self.mapper.iterate_properties \
940                                      if isinstance(p, ColumnProperty)]
941        data = dict([(name, getattr(self, name))
942                     for name in col_prop_names if name not in exclude])
943        for rname, rdeep in deep.iteritems():
944            dbdata = getattr(self, rname)
945            #FIXME: use attribute names (ie coltoprop) instead of column names
946            fks = self.mapper.get_property(rname).remote_side
947            exclude = [c.name for c in fks]
948            if dbdata is None:
949                data[rname] = None
950            elif isinstance(dbdata, list):
951                data[rname] = [o.to_dict(rdeep, exclude) for o in dbdata]
952            else:
953                data[rname] = dbdata.to_dict(rdeep, exclude)
954        return data
955
956    # session methods
957    def flush(self, *args, **kwargs):
958        return object_session(self).flush([self], *args, **kwargs)
959
960    def delete(self, *args, **kwargs):
961        return object_session(self).delete(self, *args, **kwargs)
962
963    def expire(self, *args, **kwargs):
964        return object_session(self).expire(self, *args, **kwargs)
965
966    def refresh(self, *args, **kwargs):
967        return object_session(self).refresh(self, *args, **kwargs)
968
969    def expunge(self, *args, **kwargs):
970        return object_session(self).expunge(self, *args, **kwargs)
971
972    # This bunch of session methods, along with all the query methods below
973    # only make sense when using a global/scoped/contextual session.
974    @property
975    def _global_session(self):
976        return self._descriptor.session.registry()
977
978    #FIXME: remove all deprecated methods, possibly all of these
979    def merge(self, *args, **kwargs):
980        return self._global_session.merge(self, *args, **kwargs)
981
982    def save(self, *args, **kwargs):
983        return self._global_session.save(self, *args, **kwargs)
984
985    def update(self, *args, **kwargs):
986        return self._global_session.update(self, *args, **kwargs)
987
988    # only exist in SA < 0.5
989    # IMO, the replacement (session.add) doesn't sound good enough to be added
990    # here. For example: "o = Order(); o.add()" is not very telling. It's
991    # better to leave it as "session.add(o)"
992    def save_or_update(self, *args, **kwargs):
993        return self._global_session.save_or_update(self, *args, **kwargs)
994
995    # query methods
996    @classmethod
997    def get_by(cls, *args, **kwargs):
998        """
999        Returns the first instance of this class matching the given criteria.
1000        This is equivalent to:
1001        session.query(MyClass).filter_by(...).first()
1002        """
1003        return cls.query.filter_by(*args, **kwargs).first()
1004
1005    @classmethod
1006    def get(cls, *args, **kwargs):
1007        """
1008        Return the instance of this class based on the given identifier,
1009        or None if not found. This is equivalent to:
1010        session.query(MyClass).get(...)
1011        """
1012        return cls.query.get(*args, **kwargs)
1013
1014
1015class Entity(EntityBase):
1016    '''
1017    The base class for all entities
1018
1019    All Elixir model objects should inherit from this class. Statements can
1020    appear within the body of the definition of an entity to define its
1021    fields, relationships, and other options.
1022
1023    Here is an example:
1024
1025    .. sourcecode:: python
1026
1027        class Person(Entity):
1028            name = Field(Unicode(128))
1029            birthdate = Field(DateTime, default=datetime.now)
1030
1031    Please note, that if you don't specify any primary keys, Elixir will
1032    automatically create one called ``id``.
1033
1034    For further information, please refer to the provided examples or
1035    tutorial.
1036    '''
1037    __metaclass__ = EntityMeta
1038
Note: See TracBrowser for help on using the browser.