root / elixir / trunk / elixir / entity.py @ 512

Revision 512, 37.7 kB (checked in by ged, 5 years ago)

- allows abstract classes to set default options
- allow mixed type inheritance (classes can inherit from a concrete class and

abstract classes at the same time).

- cleaned up slightly the abstract patch

Line 
1'''
2This module provides the ``Entity`` base class, as well as its metaclass
3``EntityMeta``.
4'''
5
6import sys
7import types
8import warnings
9
10from copy import deepcopy
11
12import sqlalchemy
13from sqlalchemy import Table, Column, Integer, desc, ForeignKey, and_, \
14                       ForeignKeyConstraint
15from sqlalchemy.orm import MapperExtension, mapper, object_session, \
16                           EXT_CONTINUE, polymorphic_union, ScopedSession, \
17                           ColumnProperty
18from sqlalchemy.sql import ColumnCollection
19
20import elixir
21from elixir.statements import process_mutators, MUTATORS
22from elixir import options
23from elixir.properties import Property
24
25DEBUG = False
26
27__doc_all__ = ['Entity', 'EntityMeta']
28
29
30def session_mapper_factory(scoped_session):
31    def session_mapper(cls, *args, **kwargs):
32        if kwargs.pop('save_on_init', True):
33            old_init = cls.__init__
34            def __init__(self, *args, **kwargs):
35                old_init(self, *args, **kwargs)
36                scoped_session.add(self)
37            cls.__init__ = __init__
38        cls.query = scoped_session.query_property()
39        return mapper(cls, *args, **kwargs)
40    return session_mapper
41
42
43class EntityDescriptor(object):
44    '''
45    EntityDescriptor describes fields and options needed for table creation.
46    '''
47
48    def __init__(self, entity):
49        self.entity = entity
50        # entity.__module__ is not always reliable (eg in mod_python)
51        self.module = sys.modules.get(entity.__module__)
52
53        self.builders = []
54
55        self.parent = None
56        #XXX: use entity.__subclasses__ ?
57        self.children = []
58
59        bases = []
60        for base in entity.__bases__:
61            if isinstance(base, EntityMeta):
62                if is_entity(base) and not is_abstract_entity(base):
63                    if self.parent:
64                        raise Exception(
65                            '%s entity inherits from several entities, '
66                            'and this is not supported.'
67                            % self.entity.__name__)
68                    else:
69                        self.parent = base
70                        bases.extend(base._descriptor.bases)
71                        self.parent._descriptor.children.append(entity)
72                else:
73                    bases.append(base)
74        self.bases = bases
75
76        if not is_entity(entity):
77            return
78
79        # used for multi-table inheritance
80        self.join_condition = None
81        self.has_pk = False
82        self._pk_col_done = False
83
84        # columns and constraints waiting for a table to exist
85        self._columns = ColumnCollection()
86        self.constraints = []
87
88        # properties (it is only useful for checking dupe properties at the
89        # moment, and when adding properties before the mapper is created,
90        # which shouldn't happen).
91        self.properties = {}
92
93        #
94        self.relationships = []
95
96        # set default value for options
97        self.table_args = []
98
99        # base class(es) options_defaults
100        base_defaults = {}
101        for base in self.bases:
102            base_defaults.update(getattr(base, 'options_defaults', {}))
103
104        complete_defaults = options.options_defaults.copy()
105        complete_defaults.update({
106            'metadata': elixir.metadata,
107            'session': elixir.session,
108            'collection': elixir.entities
109        })
110
111        # set default value for other options
112        for key in options.valid_options:
113            value = base_defaults.get(key, complete_defaults[key])
114            if isinstance(value, dict):
115                value = value.copy()
116            setattr(self, key, value)
117
118        # override options with module-level defaults defined
119        for key in ('metadata', 'session', 'collection'):
120            attr = '__%s__' % key
121            if hasattr(self.module, attr):
122                setattr(self, key, getattr(self.module, attr))
123
124    def setup_options(self):
125        '''
126        Setup any values that might depend on the "using_options" class
127        mutator. For example, the tablename or the metadata.
128        '''
129        elixir.metadatas.add(self.metadata)
130        if self.collection is not None:
131            self.collection.append(self.entity)
132
133        entity = self.entity
134        if self.parent:
135            if self.inheritance == 'single':
136                self.tablename = self.parent._descriptor.tablename
137
138        if not self.tablename:
139            if self.shortnames:
140                self.tablename = entity.__name__.lower()
141            else:
142                modulename = entity.__module__.replace('.', '_')
143                tablename = "%s_%s" % (modulename, entity.__name__)
144                self.tablename = tablename.lower()
145        elif hasattr(self.tablename, '__call__'):
146            self.tablename = self.tablename(entity)
147
148        if not self.identity:
149            if 'polymorphic_identity' in self.mapper_options:
150                self.identity = self.mapper_options['polymorphic_identity']
151            else:
152                #TODO: include module name (We could have b.Account inherit
153                # from a.Account)
154                self.identity = entity.__name__.lower()
155        elif 'polymorphic_identity' in self.mapper_options:
156            raise Exception('You cannot use the "identity" option and the '
157                            'polymorphic_identity mapper option at the same '
158                            'time.')
159        elif hasattr(self.identity, '__call__'):
160            self.identity = self.identity(entity)
161
162        if self.polymorphic:
163            if not isinstance(self.polymorphic, basestring):
164                self.polymorphic = options.DEFAULT_POLYMORPHIC_COL_NAME
165
166    #---------------------
167    # setup phase methods
168
169    def setup_autoload_table(self):
170        self.setup_table(True)
171
172    def create_pk_cols(self):
173        """
174        Create primary_key columns. That is, call the 'create_pk_cols'
175        builders then add a primary key to the table if it hasn't already got
176        one and needs one.
177
178        This method is "semi-recursive" in some cases: it calls the
179        create_keys method on ManyToOne relationships and those in turn call
180        create_pk_cols on their target. It shouldn't be possible to have an
181        infinite loop since a loop of primary_keys is not a valid situation.
182        """
183        if self._pk_col_done:
184            return
185
186        self.call_builders('create_pk_cols')
187
188        if not self.autoload:
189            if self.parent:
190                if self.inheritance == 'multi':
191                    # Add columns with foreign keys to the parent's primary
192                    # key columns
193                    parent_desc = self.parent._descriptor
194                    tablename = parent_desc.table_fullname
195                    join_clauses = []
196                    for pk_col in parent_desc.primary_keys:
197                        colname = options.MULTIINHERITANCECOL_NAMEFORMAT % \
198                                  {'entity': self.parent.__name__.lower(),
199                                   'key': pk_col.key}
200
201                        # It seems like SA ForeignKey is not happy being given
202                        # a real column object when said column is not yet
203                        # attached to a table
204                        pk_col_name = "%s.%s" % (tablename, pk_col.key)
205                        fk = ForeignKey(pk_col_name, ondelete='cascade')
206                        col = Column(colname, pk_col.type, fk,
207                                     primary_key=True)
208                        self.add_column(col)
209                        join_clauses.append(col == pk_col)
210                    self.join_condition = and_(*join_clauses)
211                elif self.inheritance == 'concrete':
212                    # Copy primary key columns from the parent.
213                    for col in self.parent._descriptor.columns:
214                        if col.primary_key:
215                            self.add_column(col.copy())
216            elif not self.has_pk and self.auto_primarykey:
217                if isinstance(self.auto_primarykey, basestring):
218                    colname = self.auto_primarykey
219                else:
220                    colname = options.DEFAULT_AUTO_PRIMARYKEY_NAME
221
222                self.add_column(
223                    Column(colname, options.DEFAULT_AUTO_PRIMARYKEY_TYPE,
224                           primary_key=True))
225        self._pk_col_done = True
226
227    def setup_relkeys(self):
228        self.call_builders('create_non_pk_cols')
229
230    def before_table(self):
231        self.call_builders('before_table')
232
233    def setup_table(self, only_autoloaded=False):
234        '''
235        Create a SQLAlchemy table-object with all columns that have been
236        defined up to this point.
237        '''
238        if self.entity.table is not None:
239            return
240
241        if self.autoload != only_autoloaded:
242            return
243
244        kwargs = self.table_options
245        if self.autoload:
246            args = self.table_args
247            kwargs['autoload'] = True
248        else:
249            if self.parent:
250                if self.inheritance == 'single':
251                    # we know the parent is setup before the child
252                    self.entity.table = self.parent.table
253
254                    # re-add the entity columns to the parent entity so that
255                    # they are added to the parent's table (whether the
256                    # parent's table is already setup or not).
257                    for col in self._columns:
258                        self.parent._descriptor.add_column(col)
259                    for constraint in self.constraints:
260                        self.parent._descriptor.add_constraint(constraint)
261                    return
262                elif self.inheritance == 'concrete':
263                    #TODO: we should also copy columns from the parent table
264                    # if the parent is a base (abstract?) entity (whatever the
265                    # inheritance type -> elif will need to be changed)
266
267                    # Copy all non-primary key columns from parent table
268                    # (primary key columns have already been copied earlier).
269                    for col in self.parent._descriptor.columns:
270                        if not col.primary_key:
271                            self.add_column(col.copy())
272
273                    for con in self.parent._descriptor.constraints:
274                        self.add_constraint(
275                            ForeignKeyConstraint(
276                                [e.parent.key for e in con.elements],
277                                [e.target_fullname for e in con.elements],
278                                name=con.name, #TODO: modify it
279                                onupdate=con.onupdate, ondelete=con.ondelete,
280                                use_alter=con.use_alter))
281
282            if self.polymorphic and \
283               self.inheritance in ('single', 'multi') and \
284               self.children and not self.parent:
285                self.add_column(Column(self.polymorphic,
286                                       options.POLYMORPHIC_COL_TYPE))
287
288            if self.version_id_col:
289                if not isinstance(self.version_id_col, basestring):
290                    self.version_id_col = options.DEFAULT_VERSION_ID_COL_NAME
291                self.add_column(Column(self.version_id_col, Integer))
292
293            args = list(self.columns) + self.constraints + self.table_args
294        self.entity.table = Table(self.tablename, self.metadata,
295                                  *args, **kwargs)
296        if DEBUG:
297            print self.entity.table.repr2()
298
299    def setup_reltables(self):
300        self.call_builders('create_tables')
301
302    def after_table(self):
303        self.call_builders('after_table')
304
305    def setup_events(self):
306        def make_proxy_method(methods):
307            def proxy_method(self, mapper, connection, instance):
308                for func in methods:
309                    ret = func(instance)
310                    # I couldn't commit myself to force people to
311                    # systematicaly return EXT_CONTINUE in all their event
312                    # methods.
313                    # But not doing that diverge to how SQLAlchemy works.
314                    # I should try to convince Mike to do EXT_CONTINUE by
315                    # default, and stop processing as the special case.
316#                    if ret != EXT_CONTINUE:
317                    if ret is not None and ret != EXT_CONTINUE:
318                        return ret
319                return EXT_CONTINUE
320            return proxy_method
321
322        # create a list of callbacks for each event
323        methods = {}
324
325        all_methods = getmembers(self.entity,
326                                 lambda a: isinstance(a, types.MethodType))
327
328        for name, method in all_methods:
329            for event in getattr(method, '_elixir_events', []):
330                event_methods = methods.setdefault(event, [])
331                event_methods.append(method)
332
333        if not methods:
334            return
335
336        # transform that list into methods themselves
337        for event in methods:
338            methods[event] = make_proxy_method(methods[event])
339
340        # create a custom mapper extension class, tailored to our entity
341        ext = type('EventMapperExtension', (MapperExtension,), methods)()
342
343        # then, make sure that the entity's mapper has our mapper extension
344        self.add_mapper_extension(ext)
345
346    def before_mapper(self):
347        self.call_builders('before_mapper')
348
349    def _get_children(self):
350        children = self.children[:]
351        for child in self.children:
352            children.extend(child._descriptor._get_children())
353        return children
354
355    def translate_order_by(self, order_by):
356        if isinstance(order_by, basestring):
357            order_by = [order_by]
358
359        order = []
360        for colname in order_by:
361            col = self.get_column(colname.strip('-'))
362            if colname.startswith('-'):
363                col = desc(col)
364            order.append(col)
365        return order
366
367    def setup_mapper(self):
368        '''
369        Initializes and assign a mapper to the entity.
370        At this point the mapper will usually have no property as they are
371        added later.
372        '''
373        if self.entity.mapper:
374            return
375
376        # for now we don't support the "abstract" parent class in a concrete
377        # inheritance scenario as demonstrated in
378        # sqlalchemy/test/orm/inheritance/concrete.py
379        # this should be added along other
380        kwargs = {}
381        if self.order_by:
382            kwargs['order_by'] = self.translate_order_by(self.order_by)
383
384        if self.version_id_col:
385            kwargs['version_id_col'] = self.get_column(self.version_id_col)
386
387        if self.inheritance in ('single', 'concrete', 'multi'):
388            if self.parent and \
389               (self.inheritance != 'concrete' or self.polymorphic):
390                # non-polymorphic concrete doesn't need this
391                kwargs['inherits'] = self.parent.mapper
392
393            if self.inheritance == 'multi' and self.parent:
394                kwargs['inherit_condition'] = self.join_condition
395
396            if self.polymorphic:
397                if self.children:
398                    if self.inheritance == 'concrete':
399                        keys = [(self.identity, self.entity.table)]
400                        keys.extend([(child._descriptor.identity, child.table)
401                                     for child in self._get_children()])
402                        # Having the same alias name for an entity and one of
403                        # its child (which is a parent itself) shouldn't cause
404                        # any problem because the join shouldn't be used at
405                        # the same time. But in reality, some versions of SA
406                        # do misbehave on this. Since it doesn't hurt to have
407                        # different names anyway, here they go.
408                        pjoin = polymorphic_union(
409                                    dict(keys), self.polymorphic,
410                                    'pjoin_%s' % self.identity)
411
412                        kwargs['with_polymorphic'] = ('*', pjoin)
413                        kwargs['polymorphic_on'] = \
414                            getattr(pjoin.c, self.polymorphic)
415                    elif not self.parent:
416                        kwargs['polymorphic_on'] = \
417                            self.get_column(self.polymorphic)
418
419                if self.children or self.parent:
420                    kwargs['polymorphic_identity'] = self.identity
421
422                if self.parent and self.inheritance == 'concrete':
423                    kwargs['concrete'] = True
424
425        if self.parent and self.inheritance == 'single':
426            args = []
427        else:
428            args = [self.entity.table]
429
430        # let user-defined kwargs override Elixir-generated ones, though that's
431        # not very usefull since most of them expect Column instances.
432        kwargs.update(self.mapper_options)
433
434        #TODO: document this!
435        if 'primary_key' in kwargs:
436            cols = self.entity.table.c
437            kwargs['primary_key'] = [getattr(cols, colname) for
438                colname in kwargs['primary_key']]
439
440        # do the mapping
441        if self.session is None:
442            self.entity.mapper = mapper(self.entity, *args, **kwargs)
443        elif isinstance(self.session, ScopedSession):
444            session_mapper = session_mapper_factory(self.session)
445            self.entity.mapper = session_mapper(self.entity, *args, **kwargs)
446        else:
447            raise Exception("Failed to map entity '%s' with its table or "
448                            "selectable. You can only bind an Entity to a "
449                            "ScopedSession object or None for manual session "
450                            "management."
451                            % self.entity.__name__)
452
453    def after_mapper(self):
454        self.call_builders('after_mapper')
455
456    def setup_properties(self):
457        self.call_builders('create_properties')
458
459    def finalize(self):
460        self.call_builders('finalize')
461        self.entity._setup_done = True
462
463    #----------------
464    # helper methods
465
466    def call_builders(self, what):
467        for builder in self.builders:
468            if hasattr(builder, what):
469                getattr(builder, what)()
470
471    def add_column(self, col, check_duplicate=None):
472        '''when check_duplicate is None, the value of the allowcoloverride
473        option of the entity is used.
474        '''
475        if check_duplicate is None:
476            check_duplicate = not self.allowcoloverride
477
478        if col.key in self._columns:
479            if check_duplicate:
480                raise Exception("Column '%s' already exist in '%s' ! " %
481                                (col.key, self.entity.__name__))
482            else:
483                del self._columns[col.key]
484        self._columns.add(col)
485
486        if col.primary_key:
487            self.has_pk = True
488
489        table = self.entity.table
490        if table is not None:
491            if check_duplicate and col.key in table.columns.keys():
492                raise Exception("Column '%s' already exist in table '%s' ! " %
493                                (col.key, table.name))
494            table.append_column(col)
495            if DEBUG:
496                print "table.append_column(%s)" % col
497
498    def add_constraint(self, constraint):
499        self.constraints.append(constraint)
500
501        table = self.entity.table
502        if table is not None:
503            table.append_constraint(constraint)
504
505    def add_property(self, name, property, check_duplicate=True):
506        if check_duplicate and name in self.properties:
507            raise Exception("property '%s' already exist in '%s' ! " %
508                            (name, self.entity.__name__))
509        self.properties[name] = property
510
511#FIXME: something like this is needed to propagate the relationships from
512# parent entities to their children in a concrete inheritance scenario. But
513# this doesn't work because of the backref matching code. In most case
514# (test_concrete.py) it doesn't even happen at all.
515#        if self.children and self.inheritance == 'concrete':
516#            for child in self.children:
517#                child._descriptor.add_property(name, property)
518
519        mapper = self.entity.mapper
520        if mapper:
521            mapper.add_property(name, property)
522            if DEBUG:
523                print "mapper.add_property('%s', %s)" % (name, repr(property))
524
525    def add_mapper_extension(self, extension):
526        extensions = self.mapper_options.get('extension', [])
527        if not isinstance(extensions, list):
528            extensions = [extensions]
529        extensions.append(extension)
530        self.mapper_options['extension'] = extensions
531
532    def get_column(self, key, check_missing=True):
533        #TODO: this needs to work whether the table is already setup or not
534        #TODO: support SA table/autoloaded entity
535        try:
536            return self.columns[key]
537        except KeyError:
538            if check_missing:
539                raise Exception("No column named '%s' found in the table of "
540                                "the '%s' entity!"
541                                % (key, self.entity.__name__))
542
543    def get_inverse_relation(self, rel, check_reverse=True):
544        '''
545        Return the inverse relation of rel, if any, None otherwise.
546        '''
547
548        matching_rel = None
549        for other_rel in self.relationships:
550            if rel.is_inverse(other_rel):
551                if matching_rel is None:
552                    matching_rel = other_rel
553                else:
554                    raise Exception(
555                            "Several relations match as inverse of the '%s' "
556                            "relation in entity '%s'. You should specify "
557                            "inverse relations manually by using the inverse "
558                            "keyword."
559                            % (rel.name, rel.entity.__name__))
560        # When a matching inverse is found, we check that it has only
561        # one relation matching as its own inverse. We don't need the result
562        # of the method though. But we do need to be careful not to start an
563        # infinite recursive loop.
564        if matching_rel and check_reverse:
565            rel.entity._descriptor.get_inverse_relation(matching_rel, False)
566
567        return matching_rel
568
569    def find_relationship(self, name):
570        for rel in self.relationships:
571            if rel.name == name:
572                return rel
573        if self.parent:
574            return self.parent._descriptor.find_relationship(name)
575        else:
576            return None
577
578    #------------------------
579    # some useful properties
580
581    @property
582    def table_fullname(self):
583        '''
584        Complete name of the table for the related entity.
585        Includes the schema name if there is one specified.
586        '''
587        schema = self.table_options.get('schema', None)
588        if schema is not None:
589            return "%s.%s" % (schema, self.tablename)
590        else:
591            return self.tablename
592
593    @property
594    def columns(self):
595        if self.entity.table is not None:
596            return self.entity.table.columns
597        else:
598            #FIXME: depending on the type of inheritance, we should also
599            # return the parent entity's columns (for example for order_by
600            # using a column defined in the parent.
601            return self._columns
602
603    @property
604    def primary_keys(self):
605        """
606        Returns the list of primary key columns of the entity.
607
608        This property isn't valid before the "create_pk_cols" phase.
609        """
610        if self.autoload:
611            return [col for col in self.entity.table.primary_key.columns]
612        else:
613            if self.parent and self.inheritance == 'single':
614                return self.parent._descriptor.primary_keys
615            else:
616                return [col for col in self.columns if col.primary_key]
617
618    @property
619    def table(self):
620        if self.entity.table is not None:
621            return self.entity.table
622        else:
623            return FakeTable(self)
624
625    @property
626    def primary_key_properties(self):
627        """
628        Returns the list of (mapper) properties corresponding to the primary
629        key columns of the table of the entity.
630
631        This property caches its value, so it shouldn't be called before the
632        entity is fully set up.
633        """
634        if not hasattr(self, '_pk_props'):
635            col_to_prop = {}
636            mapper = self.entity.mapper
637            for prop in mapper.iterate_properties:
638                if isinstance(prop, ColumnProperty):
639                    for col in prop.columns:
640                        for col in col.proxy_set:
641                            col_to_prop[col] = prop
642            pk_cols = [c for c in mapper.mapped_table.c if c.primary_key]
643            self._pk_props = [col_to_prop[c] for c in pk_cols]
644        return self._pk_props
645
646class FakePK(object):
647    def __init__(self, descriptor):
648        self.descriptor = descriptor
649
650    @property
651    def columns(self):
652        return self.descriptor.primary_keys
653
654class FakeTable(object):
655    def __init__(self, descriptor):
656        self.descriptor = descriptor
657        self.primary_key = FakePK(descriptor)
658
659    @property
660    def columns(self):
661        return self.descriptor.columns
662
663    @property
664    def fullname(self):
665        '''
666        Complete name of the table for the related entity.
667        Includes the schema name if there is one specified.
668        '''
669        schema = self.descriptor.table_options.get('schema', None)
670        if schema is not None:
671            return "%s.%s" % (schema, self.descriptor.tablename)
672        else:
673            return self.descriptor.tablename
674
675
676def is_entity(cls):
677    """
678    Scan the bases classes of `cls` to see if any is an instance of
679    EntityMeta. If we don't find any, it means it is either an unrelated class
680    or an entity base class (like the 'Entity' class).
681    """
682    for base in cls.__bases__:
683        if isinstance(base, EntityMeta):
684            return True
685    return False
686
687
688# Note that we don't use inspect.getmembers because of
689# http://bugs.python.org/issue1785
690# See also http://elixir.ematia.de/trac/changeset/262
691def getmembers(object, predicate=None):
692    base_props = []
693    for key in dir(object):
694        try:
695            value = getattr(object, key)
696        except AttributeError:
697            continue
698        if not predicate or predicate(value):
699            base_props.append((key, value))
700    return base_props
701
702def is_abstract_entity(dict_or_cls):
703    if not isinstance(dict_or_cls, dict):
704        dict_or_cls = dict_or_cls.__dict__
705    for mutator, args, kwargs in dict_or_cls.get(MUTATORS, []):
706        if 'abstract' in kwargs:
707            return kwargs['abstract']
708
709    return False
710
711def instrument_class(cls):
712    """
713    Instrument a class as an Entity. This is usually done automatically through
714    the EntityMeta metaclass.
715    """
716    # Create the entity descriptor
717    #XXX: we might want to use a simplified descriptor for bases and abstract
718    # classes which only need "bases", "options_defaults" and "abstract"
719    desc = cls._descriptor = EntityDescriptor(cls)
720
721    # Process mutators
722    # We *do* want mutators to be processed for base/abstract classes
723    # (so that statements like using_options_defaults work).
724    process_mutators(cls)
725
726    # We do not want to do any more processing for base/abstract classes
727    # (Entity et al.).
728    if not is_entity(cls) or is_abstract_entity(cls):
729        return
730
731    cls.table = None
732    cls.mapper = None
733
734    # Copy the properties ('Property' instances) of the entity base class(es).
735    # We use getmembers (instead of __dict__) so that we also get the
736    # properties from the parents of the base class if any.
737    base_props = []
738    for base in cls.__bases__:
739        if isinstance(base, EntityMeta) and \
740           (not is_entity(base) or is_abstract_entity(base)):
741            base_props += [(name, deepcopy(attr)) for name, attr in
742                           getmembers(base, lambda a: isinstance(a, Property))]
743
744    # Process attributes (using the assignment syntax), looking for
745    # 'Property' instances and attaching them to this entity.
746    properties = [(name, attr) for name, attr in cls.__dict__.iteritems()
747                               if isinstance(attr, Property)]
748    sorted_props = sorted(base_props + properties,
749                          key=lambda i: i[1]._counter)
750    for name, prop in sorted_props:
751        prop.attach(cls, name)
752
753    # setup misc options here (like tablename etc.)
754    desc.setup_options()
755
756
757class EntityMeta(type):
758    """
759    Entity meta class.
760    You should only use it directly if you want to define your own base class
761    for your entities (ie you don't want to use the provided 'Entity' class).
762    """
763
764    def __init__(cls, name, bases, dict_):
765        instrument_class(cls)
766
767    def __setattr__(cls, key, value):
768        if isinstance(value, Property):
769            if hasattr(cls, '_setup_done'):
770                raise Exception('Cannot set attribute on a class after '
771                                'setup_all')
772            else:
773                value.attach(cls, key)
774        else:
775            type.__setattr__(cls, key, value)
776
777
778def setup_entities(entities):
779    '''Setup all entities in the list passed as argument'''
780
781    for entity in entities:
782        # delete all Elixir properties so that it doesn't interfere with
783        # SQLAlchemy. At this point they should have be converted to
784        # builders.
785        for name, attr in entity.__dict__.items():
786            if isinstance(attr, Property):
787                delattr(entity, name)
788
789    for method_name in (
790            'setup_autoload_table', 'create_pk_cols', 'setup_relkeys',
791            'before_table', 'setup_table', 'setup_reltables', 'after_table',
792            'setup_events',
793            'before_mapper', 'setup_mapper', 'after_mapper',
794            'setup_properties',
795            'finalize'):
796#        if DEBUG:
797#            print "=" * 40
798#            print method_name
799#            print "=" * 40
800        for entity in entities:
801#            print entity.__name__, "...",
802            if hasattr(entity, '_setup_done'):
803#                print "already done"
804                continue
805            method = getattr(entity._descriptor, method_name)
806            method()
807#            print "ok"
808
809
810def cleanup_entities(entities):
811    """
812    Try to revert back the list of entities passed as argument to the state
813    they had just before their setup phase.
814
815    As of now, this function is *not* functional in that it doesn't revert to
816    the exact same state the entities were before setup. For example, the
817    properties do not work yet as those would need to be regenerated (since the
818    columns they are based on are regenerated too -- and as such the
819    corresponding joins are not correct) but this doesn't happen because of
820    the way relationship setup is designed to be called only once (especially
821    the backref stuff in create_properties).
822    """
823    for entity in entities:
824        desc = entity._descriptor
825
826        if hasattr(entity, '_setup_done'):
827            del entity._setup_done
828
829        entity.table = None
830        entity.mapper = None
831
832        desc._pk_col_done = False
833        desc.has_pk = False
834        desc._columns = ColumnCollection()
835        desc.constraints = []
836        desc.properties = {}
837
838class EntityBase(object):
839    """
840    This class holds all methods of the "Entity" base class, but does not act
841    as a base class itself (it does not use the EntityMeta metaclass), but
842    rather as a parent class for Entity. This is meant so that people who want
843    to provide their own base class but don't want to loose or copy-paste all
844    the methods of Entity can do so by inheriting from EntityBase:
845
846    .. sourcecode:: python
847
848        class MyBase(EntityBase):
849            __metaclass__ = EntityMeta
850
851            def myCustomMethod(self):
852                # do something great
853    """
854
855    def __init__(self, **kwargs):
856        self.set(**kwargs)
857
858    def set(self, **kwargs):
859        for key, value in kwargs.iteritems():
860            setattr(self, key, value)
861
862    @classmethod
863    def update_or_create(cls, data, surrogate=True):
864        pk_props = cls._descriptor.primary_key_properties
865
866        # if all pk are present and not None
867        if not [1 for p in pk_props if data.get(p.key) is None]:
868            pk_tuple = tuple([data[prop.key] for prop in pk_props])
869            record = cls.query.get(pk_tuple)
870            if record is None:
871                if surrogate:
872                    raise Exception("cannot create surrogate with pk")
873                else:
874                    record = cls()
875        else:
876            if surrogate:
877                record = cls()
878            else:
879                raise Exception("cannot create non surrogate without pk")
880        record.from_dict(data)
881        return record
882
883    def from_dict(self, data):
884        """
885        Update a mapped class with data from a JSON-style nested dict/list
886        structure.
887        """
888        # surrogate can be guessed from autoincrement/sequence but I guess
889        # that's not 100% reliable, so we'll need an override
890
891        mapper = sqlalchemy.orm.object_mapper(self)
892
893        for key, value in data.iteritems():
894            if isinstance(value, dict):
895                dbvalue = getattr(self, key)
896                rel_class = mapper.get_property(key).mapper.class_
897                pk_props = rel_class._descriptor.primary_key_properties
898
899                # If the data doesn't contain any pk, and the relationship
900                # already has a value, update that record.
901                if not [1 for p in pk_props if p.key in data] and \
902                   dbvalue is not None:
903                    dbvalue.from_dict(value)
904                else:
905                    record = rel_class.update_or_create(value)
906                    setattr(self, key, record)
907            elif isinstance(value, list) and \
908                 value and isinstance(value[0], dict):
909
910                rel_class = mapper.get_property(key).mapper.class_
911                new_attr_value = []
912                for row in value:
913                    if not isinstance(row, dict):
914                        raise Exception(
915                                'Cannot send mixed (dict/non dict) data '
916                                'to list relationships in from_dict data.')
917                    record = rel_class.update_or_create(row)
918                    new_attr_value.append(record)
919                setattr(self, key, new_attr_value)
920            else:
921                setattr(self, key, value)
922
923    def to_dict(self, deep={}, exclude=[]):
924        """Generate a JSON-style nested dict/list structure from an object."""
925        col_prop_names = [p.key for p in self.mapper.iterate_properties \
926                                      if isinstance(p, ColumnProperty)]
927        data = dict([(name, getattr(self, name))
928                     for name in col_prop_names if name not in exclude])
929        for rname, rdeep in deep.iteritems():
930            dbdata = getattr(self, rname)
931            #FIXME: use attribute names (ie coltoprop) instead of column names
932            fks = self.mapper.get_property(rname).remote_side
933            exclude = [c.name for c in fks]
934            if dbdata is None:
935                data[rname] = None
936            elif isinstance(dbdata, list):
937                data[rname] = [o.to_dict(rdeep, exclude) for o in dbdata]
938            else:
939                data[rname] = dbdata.to_dict(rdeep, exclude)
940        return data
941
942    # session methods
943    def flush(self, *args, **kwargs):
944        return object_session(self).flush([self], *args, **kwargs)
945
946    def delete(self, *args, **kwargs):
947        return object_session(self).delete(self, *args, **kwargs)
948
949    def expire(self, *args, **kwargs):
950        return object_session(self).expire(self, *args, **kwargs)
951
952    def refresh(self, *args, **kwargs):
953        return object_session(self).refresh(self, *args, **kwargs)
954
955    def expunge(self, *args, **kwargs):
956        return object_session(self).expunge(self, *args, **kwargs)
957
958    # This bunch of session methods, along with all the query methods below
959    # only make sense when using a global/scoped/contextual session.
960    @property
961    def _global_session(self):
962        return self._descriptor.session.registry()
963
964    def merge(self, *args, **kwargs):
965        return self._global_session.merge(self, *args, **kwargs)
966
967    def save(self, *args, **kwargs):
968        return self._global_session.save(self, *args, **kwargs)
969
970    def update(self, *args, **kwargs):
971        return self._global_session.update(self, *args, **kwargs)
972
973    # only exist in SA < 0.5
974    # IMO, the replacement (session.add) doesn't sound good enough to be added
975    # here. For example: "o = Order(); o.add()" is not very telling. It's
976    # better to leave it as "session.add(o)"
977    def save_or_update(self, *args, **kwargs):
978        return self._global_session.save_or_update(self, *args, **kwargs)
979
980    # query methods
981    @classmethod
982    def get_by(cls, *args, **kwargs):
983        """
984        Returns the first instance of this class matching the given criteria.
985        This is equivalent to:
986        session.query(MyClass).filter_by(...).first()
987        """
988        return cls.query.filter_by(*args, **kwargs).first()
989
990    @classmethod
991    def get(cls, *args, **kwargs):
992        """
993        Return the instance of this class based on the given identifier,
994        or None if not found. This is equivalent to:
995        session.query(MyClass).get(...)
996        """
997        return cls.query.get(*args, **kwargs)
998
999
1000class Entity(EntityBase):
1001    '''
1002    The base class for all entities
1003
1004    All Elixir model objects should inherit from this class. Statements can
1005    appear within the body of the definition of an entity to define its
1006    fields, relationships, and other options.
1007
1008    Here is an example:
1009
1010    .. sourcecode:: python
1011
1012        class Person(Entity):
1013            name = Field(Unicode(128))
1014            birthdate = Field(DateTime, default=datetime.now)
1015
1016    Please note, that if you don't specify any primary keys, Elixir will
1017    automatically create one called ``id``.
1018
1019    For further information, please refer to the provided examples or
1020    tutorial.
1021    '''
1022    __metaclass__ = EntityMeta
1023
Note: See TracBrowser for help on using the browser.