root / elixir / trunk / elixir / entity.py @ 199

Revision 199, 24.5 kB (checked in by ged, 6 years ago)
  • Added test for the case when you refer to a remotely-defined class by its
    named after importing it into the local namespace.
  • Implemented a new syntax to declare fields and relationships, much closer to
    what is found in other Python ORM's. The with_fields syntax is now
    deprecated in favor of a that new syntax. The old statement-based (has_field et
    al.) syntax stays the default for now. This was done with help from a patch
    by Adam Gomaa.
  • Relationships to other classes can now also be defined using the classes
    themselves in addition to the class names. Obviously, this doesn't work for
    forward references.
Line 
1'''
2Entity baseclass, metaclass and descriptor
3'''
4
5import sqlalchemy
6
7from sqlalchemy                     import Table, Integer, String, desc,\
8                                           ForeignKey, and_
9from sqlalchemy.orm                 import deferred, Query, MapperExtension
10from sqlalchemy.ext.assignmapper    import assign_mapper
11from sqlalchemy.ext.sessioncontext  import SessionContext
12from sqlalchemy.util                import OrderedDict
13
14from elixir.statements              import Statement
15from elixir.fields                  import Field
16from elixir.options                 import options_defaults
17
18try:
19    set
20except NameError:
21    from sets import Set as set
22
23import sys
24import warnings
25
26import elixir
27import inspect
28
29__pudge_all__ = ['Entity', 'EntityMeta']
30
31DEFAULT_AUTO_PRIMARYKEY_NAME = "id"
32DEFAULT_AUTO_PRIMARYKEY_TYPE = Integer
33DEFAULT_VERSION_ID_COL = "row_version"
34DEFAULT_POLYMORPHIC_COL_NAME = "row_type"
35DEFAULT_POLYMORPHIC_COL_SIZE = 40
36DEFAULT_POLYMORPHIC_COL_TYPE = String(DEFAULT_POLYMORPHIC_COL_SIZE)
37
38class EntityDescriptor(object):
39    '''
40    EntityDescriptor describes fields and options needed for table creation.
41    '''
42   
43    def __init__(self, entity):
44        entity.table = None
45        entity.mapper = None
46
47        self.entity = entity
48        self.module = sys.modules[entity.__module__]
49
50        self.has_pk = False
51
52        self.parent = None
53        self.children = []
54
55        for base in entity.__bases__:
56            if issubclass(base, Entity) and base is not Entity:
57                if self.parent:
58                    raise Exception('%s entity inherits from several entities,'
59                                    ' and this is not supported.' 
60                                    % self.entity.__name__)
61                else:
62                    self.parent = base
63                    self.parent._descriptor.children.append(entity)
64
65        self.fields = OrderedDict()
66        self.relationships = list()
67        self.delayed_properties = dict()
68        self.constraints = list()
69
70        # set default value for options
71        self.order_by = None
72        self.table_args = list()
73        self.metadata = getattr(self.module, 'metadata', elixir.metadata)
74
75        for option in ('inheritance', 'polymorphic',
76                       'autoload', 'tablename', 'shortnames', 
77                       'auto_primarykey',
78                       'version_id_col'):
79            setattr(self, option, options_defaults[option])
80
81        for option_dict in ('mapper_options', 'table_options'):
82            setattr(self, option_dict, options_defaults[option_dict].copy())
83   
84    def setup_options(self):
85        '''
86        Setup any values that might depend on using_options. For example, the
87        tablename or the metadata.
88        '''
89        elixir.metadatas.add(self.metadata)
90
91        entity = self.entity
92        if self.inheritance == 'concrete' and self.polymorphic:
93            raise NotImplementedError("Polymorphic concrete inheritance is "
94                                      "not yet implemented.")
95
96        if self.parent:
97            if self.inheritance == 'single':
98                self.tablename = self.parent._descriptor.tablename
99
100        if not self.tablename:
101            if self.shortnames:
102                self.tablename = entity.__name__.lower()
103            else:
104                modulename = entity.__module__.replace('.', '_')
105                tablename = "%s_%s" % (modulename, entity.__name__)
106                self.tablename = tablename.lower()
107        elif callable(self.tablename):
108            self.tablename = self.tablename(entity)
109
110    def setup_autoload_table(self):
111        self.setup_table(True)
112
113    def create_pk_cols(self):
114        """
115        Create primary_key columns. That is, add columns from belongs_to
116        relationships marked as being a primary_key and then add a primary
117        key to the table if it hasn't already got one and needs one.
118       
119        This method is "semi-recursive" in that it calls the create_keys
120        method on BelongsTo relationships and those in turn call create_pk_cols
121        on their target. It shouldn't be possible to have an infinite loop
122        since a loop of primary_keys is not a valid situation.
123        """
124        for rel in self.relationships:
125            rel.create_keys(True)
126
127        if not self.autoload:
128            if self.parent and self.inheritance == 'multi':
129                # add foreign keys to the parent's primary key columns
130                parent_desc = self.parent._descriptor
131                for pk_col in parent_desc.primary_keys:
132                    colname = "%s_%s" % (self.parent.__name__.lower(),
133                                         pk_col.key)
134
135                    # it seems like SA ForeignKey is not happy being given a
136                    # real column object when said column is not yet attached
137                    # to a table
138                    pk_col_name = "%s.%s" % (parent_desc.tablename, pk_col.key)
139                    field = Field(pk_col.type, ForeignKey(pk_col_name), 
140                                  colname=colname, primary_key=True)
141                    self.add_field(field)
142            if not self.has_pk and self.auto_primarykey:
143                #FIXME: we'll need to do a special case for concrete
144                # inheritance too
145                if self.parent and self.inheritance == 'single':
146                    return
147
148                if isinstance(self.auto_primarykey, basestring):
149                    colname = self.auto_primarykey
150                else:
151                    colname = DEFAULT_AUTO_PRIMARYKEY_NAME
152               
153                self.add_field(Field(DEFAULT_AUTO_PRIMARYKEY_TYPE,
154                                     colname=colname, primary_key=True))
155
156    def setup_relkeys(self):
157        for rel in self.relationships:
158            rel.create_keys(False)
159
160    def before_table(self):
161        Statement.process(self.entity, 'before_table')
162       
163    def setup_table(self, only_autoloaded=False):
164        '''
165        Create a SQLAlchemy table-object with all columns that have been
166        defined up to this point.
167        '''
168        if self.entity.table:
169            return
170
171        if self.autoload != only_autoloaded:
172            return
173       
174        if self.parent:
175            if self.inheritance == 'single':
176                # we know the parent is setup before the child
177                self.entity.table = self.parent.table 
178
179                # re-add the entity fields to the parent entity so that they
180                # are added to the parent's table (whether the parent's table
181                # is already setup or not).
182                for field in self.fields.itervalues():
183                    self.parent._descriptor.add_field(field)
184                for constraint in self.constraints:
185                    self.parent._descriptor.add_constraint(constraint)
186                return
187            elif self.inheritance == 'concrete':
188               # copy all fields from parent table
189               for field in self.parent._descriptor.fields.itervalues():
190                    self.add_field(field.copy())
191               #FIXME: copy constraints. But those are not as simple to copy
192               #since the source column must be changed
193
194        if self.polymorphic and self.inheritance in ('single', 'multi') and \
195           self.children and not self.parent:
196            if not isinstance(self.polymorphic, basestring):
197                self.polymorphic = DEFAULT_POLYMORPHIC_COL_NAME
198               
199            self.add_field(Field(DEFAULT_POLYMORPHIC_COL_TYPE, 
200                                 colname=self.polymorphic))
201
202        if self.version_id_col:
203            if not isinstance(self.version_id_col, basestring):
204                self.version_id_col = DEFAULT_VERSION_ID_COL
205            self.add_field(Field(Integer, colname=self.version_id_col))
206
207        # create list of columns and constraints
208        args = [field.column for field in self.fields.itervalues()] \
209                    + self.constraints + self.table_args
210       
211        # specify options
212        kwargs = self.table_options
213
214        if self.autoload:
215            kwargs['autoload'] = True
216
217        self.entity.table = Table(self.tablename, self.metadata, 
218                                  *args, **kwargs)
219
220    def setup_reltables(self):
221        for rel in self.relationships:
222            rel.create_tables()
223
224    def after_table(self):
225        Statement.process(self.entity, 'after_table')
226
227    def setup_events(self):
228        def make_proxy_method(methods):
229            def proxy_method(self, mapper, connection, instance):
230                for func in methods:
231                    func(instance)
232            return proxy_method
233
234        # create a list of callbacks for each event
235        methods = {}
236        for name, func in inspect.getmembers(self.entity, inspect.ismethod):
237            if hasattr(func, '_elixir_events'):
238                for event in func._elixir_events:
239                    event_methods = methods.setdefault(event, [])
240                    event_methods.append(func)
241       
242        if not methods:
243            return
244       
245        # transform that list into methods themselves
246        for event in methods:
247            methods[event] = make_proxy_method(methods[event])
248       
249        # create a custom mapper extension class, tailored to our entity
250        ext = type('EventMapperExtension', (MapperExtension,), methods)()
251       
252        # then, make sure that the entity's mapper has our mapper extension
253        self.add_mapper_extension(ext)
254
255    def before_mapper(self):
256        Statement.process(self.entity, 'before_mapper')
257
258    def _get_children(self):
259        children = self.children[:]
260        for child in self.children:
261            children.extend(child._descriptor._get_children())
262        return children
263
264    def evaluate_property(self, prop):
265        if callable(prop):
266            return prop(self.entity.table.c)
267        else:
268            return prop
269
270    def translate_order_by(self, order_by):
271        if isinstance(order_by, basestring):
272            order_by = [order_by]
273       
274        order = list()
275        for field in order_by:
276            col = self.fields[field.strip('-')].column
277            if field.startswith('-'):
278                col = desc(col)
279            order.append(col)
280        return order
281
282    def setup_mapper(self):
283        '''
284        Initializes and assign an (empty!) mapper to the entity.
285        '''
286        if self.entity.mapper:
287            return
288       
289        # look for a 'session' attribute assigned to the entity
290        # (or entity's base class)
291        session = getattr(self, 'session', None)
292        if session is None:
293            session = getattr(self.module, 'session', elixir.objectstore)
294        if not isinstance(session, Objectstore):
295            session = Objectstore(session)
296           
297        self.objectstore = session
298       
299        kwargs = self.mapper_options
300        if self.order_by:
301            kwargs['order_by'] = self.translate_order_by(self.order_by)
302       
303        if self.version_id_col:
304            kwargs['version_id_col'] = self.fields[self.version_id_col].column
305
306        if self.inheritance in ('single', 'concrete', 'multi'):
307            if self.parent and \
308               not (self.inheritance == 'concrete' and not self.polymorphic):
309                kwargs['inherits'] = self.parent.mapper
310
311            if self.inheritance == 'multi' and self.parent:
312                col_pairs = zip(self.primary_keys,
313                                self.parent._descriptor.primary_keys)
314                kwargs['inherit_condition'] = \
315                    and_(*[pc == c for c,pc in col_pairs])
316
317            if self.polymorphic:
318                if self.children and not self.parent:
319                    kwargs['polymorphic_on'] = \
320                        self.fields[self.polymorphic].column
321                    #TODO: this is an optimization, and it breaks the multi
322                    # table polymorphic inheritance test with a relation.
323                    # So I turn it off for now. We might want to provide an
324                    # option to turn it on.
325#                    if self.inheritance == 'multi':
326#                        children = self._get_children()
327#                        join = self.entity.table
328#                        for child in children:
329#                            join = join.outerjoin(child.table)
330#                        kwargs['select_table'] = join
331                   
332                if self.children or self.parent:
333                    #TODO: make this customizable (both callable and string)
334                    #TODO: include module name
335                    kwargs['polymorphic_identity'] = \
336                        self.entity.__name__.lower()
337
338                if self.inheritance == 'concrete':
339                    kwargs['concrete'] = True
340
341        properties = dict()
342        for field in self.fields.itervalues():
343            if field.deferred:
344                group = None
345                if isinstance(field.deferred, basestring):
346                    group = field.deferred
347                properties[field.column.name] = deferred(field.column,
348                                                         group=group)
349
350        for name, prop in self.delayed_properties.iteritems():
351            properties[name] = self.evaluate_property(prop)
352        self.delayed_properties.clear()
353
354        if 'primary_key' in kwargs:
355            cols = self.entity.table.c
356            kwargs['primary_key'] = [getattr(cols, colname) for
357                colname in kwargs['primary_key']]
358
359        if self.parent and self.inheritance == 'single':
360            args = []
361        else:
362            args = [self.entity.table]
363
364        self.objectstore.mapper(self.entity, properties=properties, 
365                                *args, **kwargs)
366
367    def after_mapper(self):
368        Statement.process(self.entity, 'after_mapper')
369
370    def setup_properties(self):
371        for rel in self.relationships:
372            rel.create_properties()
373
374    def finalize(self):
375        Statement.process(self.entity, 'finalize')
376
377    #--------------
378
379    def add_field(self, field):
380#        if field.colname in self.fields:
381#            print "duplicate field", field.colname
382        self.fields[field.colname] = field
383       
384        if field.primary_key:
385            self.has_pk = True
386
387        # we don't want to trigger setup_all too early
388        table = type.__getattribute__(self.entity, 'table')
389        if table:
390#TODO: we might want to check for that case
391#            if field.colname in table.columns.keys():
392            table.append_column(field.column)
393   
394    def add_constraint(self, constraint):
395        self.constraints.append(constraint)
396       
397        table = self.entity.table
398        if table:
399            table.append_constraint(constraint)
400       
401    def add_property(self, name, prop):
402        if self.entity.mapper:
403            prop_value = self.evaluate_property(prop)
404            self.entity.mapper.add_property(name, prop_value)
405        else:
406            self.delayed_properties[name] = prop
407   
408    def add_mapper_extension(self, extension):
409        extensions = self.mapper_options.get('extension', [])
410        if not isinstance(extensions, list):
411            extensions = [extensions]
412        extensions.append(extension)
413        self.mapper_options['extension'] = extensions
414
415    def get_inverse_relation(self, rel, reverse=False):
416        '''
417        Return the inverse relation of rel, if any, None otherwise.
418        '''
419
420        matching_rel = None
421        for other_rel in self.relationships:
422            if other_rel.is_inverse(rel):
423                if matching_rel is None:
424                    matching_rel = other_rel
425                else:
426                    raise Exception(
427                            "Several relations match as inverse of the '%s' "
428                            "relation in entity '%s'. You should specify "
429                            "inverse relations manually by using the inverse "
430                            "keyword."
431                            % (rel.name, rel.entity.__name__))
432        # When a matching inverse is found, we check that it has only
433        # one relation matching as its own inverse. We don't need the result
434        # of the method though. But we do need to be careful not to start an
435        # infinite recursive loop.
436        if matching_rel and not reverse:
437            rel.entity._descriptor.get_inverse_relation(matching_rel, True)
438
439        return matching_rel
440
441    def find_relationship(self, name):
442        for rel in self.relationships:
443            if rel.name == name:
444                return rel
445        if self.parent:
446            return self.parent.find_relationship(name)
447        else:
448            return None
449
450    def primary_keys(self):
451        if self.autoload:
452            return [col for col in self.entity.table.primary_key.columns]
453        else:
454            if self.parent and self.inheritance == 'single':
455                return self.parent._descriptor.primary_keys
456            else:
457                return [field.column for field in self.fields.itervalues() if
458                        field.primary_key]
459    primary_keys = property(primary_keys)
460
461
462class TriggerProxy(object):
463    """A class that serves as a "trigger" ; accessing its attributes runs
464    the function that is set at initialization.
465
466    Primarily used for setup_all().
467
468    Note that the `setupfunc` parameter is called on each access of
469    the attribute.
470
471    """
472    def __init__(self, class_, attrname, setupfunc):
473        self.class_ = class_
474        self.attrname = attrname
475        self.setupfunc = setupfunc
476
477    def __getattr__(self, name):
478        self.setupfunc()
479        proxied_attr = getattr(self.class_, self.attrname)
480        return getattr(proxied_attr, name)
481
482    def __repr__(self):
483        proxied_attr = getattr(self.class_, self.attrname)
484        return "<TriggerProxy (%s)>" % (self.class_.__name__)
485
486def _is_entity(class_):
487    return isinstance(class_, EntityMeta)
488
489class EntityMeta(type):
490    """
491    Entity meta class.
492    You should only use this if you want to define your own base class for your
493    entities (ie you don't want to use the provided 'Entity' class).
494    """
495    _ready = False
496    _entities = {}
497
498    def __init__(cls, name, bases, dict_):
499        # only process subclasses of Entity, not Entity itself
500        if bases[0] is object:
501            return
502
503        # build a dict of entities for each frame where there are entities
504        # defined
505        caller_frame = sys._getframe(1)
506        cid = cls._caller = id(caller_frame)
507        caller_entities = EntityMeta._entities.setdefault(cid, {})
508        caller_entities[name] = cls
509
510        # Append all entities which are currently visible by the entity. This
511        # will find more entities only if some of them where imported from
512        # another module.
513        for entity in [e for e in caller_frame.f_locals.values() 
514                         if _is_entity(e)]:
515            caller_entities[entity.__name__] = entity
516
517        # create the entity descriptor
518        desc = cls._descriptor = EntityDescriptor(cls)
519
520        # process statements. Needed before the proxy for metadata
521        Statement.process(cls)
522
523        # Process attributes, for the assignment syntax.
524        cls._process_attrs(dict_)
525
526        # setup misc options here (like tablename etc.)
527        desc.setup_options()
528
529        # create trigger proxies
530        # TODO: support entity_name... or maybe not. I'm not sure it makes
531        # sense in Elixir.
532        cls._setup_proxy()
533
534    def _setup_proxy(cls, entity_name=None):
535        #TODO: move as much as possible of those "_private" values to the
536        # descriptor, so that we don't mess the initial class.
537        cls._class_key = sqlalchemy.orm.mapperlib.ClassKey(cls, entity_name)
538
539        tablename = cls._descriptor.tablename
540        schema = cls._descriptor.table_options.get('schema', None)
541        cls._table_key = sqlalchemy.schema._get_table_key(tablename, schema)
542
543        elixir._delayed_entities.append(cls)
544       
545        mapper_proxy = TriggerProxy(cls, 'mapper', elixir.setup_all)
546        table_proxy = TriggerProxy(cls, 'table', elixir.setup_all)
547
548        sqlalchemy.orm.mapper_registry[cls._class_key] = mapper_proxy
549        md = cls._descriptor.metadata
550        md.tables[cls._table_key] = table_proxy
551
552        # We need to monkeypatch the metadata's table iterator method because
553        # otherwise it doesn't work if the setup is triggered by the
554        # metadata.create_all().
555        # This is because ManyToMany relationships add tables AFTER the list
556        # of tables that are going to be created is "computed"
557        # (metadata.tables.values()).
558        # see:
559        # - table_iterator method in MetaData class in sqlalchemy/schema.py
560        # - visit_metadata method in sqlalchemy/ansisql.py
561        original_table_iterator = md.table_iterator
562        if not hasattr(original_table_iterator, 
563                       '_non_elixir_patched_iterator'):
564            def table_iterator(*args, **kwargs):
565                elixir.setup_all()
566                return original_table_iterator(*args, **kwargs)
567            table_iterator.__doc__ = original_table_iterator.__doc__
568            table_iterator._non_elixir_patched_iterator = \
569                original_table_iterator
570            md.table_iterator = table_iterator
571
572        cls._ready = True
573
574    def _process_attrs(cls, attr_dict):
575        """Process class attributes, looking for Elixir `Field`s or
576        `Relationship`.
577        """
578
579        for name, attr in attr_dict.iteritems():
580            # Check if it's Elixir related.
581            if isinstance(attr, Field):
582                # If no colname was defined (through the 'colname' kwarg), set
583                # it to the name of the attr.
584                if attr.colname is None:
585                    attr.colname = name
586                cls._descriptor.add_field(attr)
587            elif isinstance(attr, elixir.relationships.Relationship):
588                attr.name = name 
589                attr.entity = cls
590                cls._descriptor.relationships.append(attr)
591            else:
592                # Not an Elixir field, let it be.
593                pass
594        return
595
596    def __getattribute__(cls, name):
597        if type.__getattribute__(cls, "_ready"):
598            #TODO: we need to add all assign_mapper methods
599            if name in ('c', 'table', 'mapper'):
600                elixir.setup_all()
601        return type.__getattribute__(cls, name)
602
603    def __call__(cls, *args, **kwargs):
604        elixir.setup_all()
605        return type.__call__(cls, *args, **kwargs)
606
607    def q(cls):
608        return Query(cls, session=cls._descriptor.objectstore.session)
609    q = property(q)
610
611
612class Entity(object):
613    '''
614    The base class for all entities
615   
616    All Elixir model objects should inherit from this class. Statements can
617    appear within the body of the definition of an entity to define its
618    fields, relationships, and other options.
619   
620    Here is an example:
621
622    ::
623   
624        class Person(Entity):
625            has_field('name', Unicode(128))
626            has_field('birthdate', DateTime, default=datetime.now)
627   
628    Please note, that if you don't specify any primary keys, Elixir will
629    automatically create one called ``id``.
630   
631    For further information, please refer to the provided examples or
632    tutorial.
633    '''
634    __metaclass__ = EntityMeta
635
636    def __init__(self, **kwargs):
637        for key, value in kwargs.items():
638            setattr(self, key, value)
639
640    def get_by(cls, *args, **kwargs):
641#        warnings.warn("The get_by method on the class is deprecated."
642#                      "You should use cls.query.get_by", DeprecationWarning,
643#                      stacklevel=2)
644        return cls.q.get_by(*args, **kwargs)
645    get_by = classmethod(get_by)
646
647    def select(cls, *args, **kwargs):
648#        warnings.warn("The select method on the class is deprecated."
649#                      "You should use cls.query.select", DeprecationWarning,
650#                      stacklevel=2)
651        return cls.q.select(*args, **kwargs)
652    select = classmethod(select)
653
654
655class Objectstore(object):
656    """a wrapper for a SQLAlchemy session-making object, such as
657    SessionContext or ScopedSession.
658   
659    Uses the ``registry`` attribute present on both objects
660    (versions 0.3 and 0.4) in order to return the current
661    contextual session.
662    """
663   
664    def __init__(self, ctx):
665        self.context = ctx
666        self.is_ctx = isinstance(ctx, SessionContext)
667
668    def __getattr__(self, name):
669        return getattr(self.context.registry(), name)
670   
671    def mapper(self, cls, *args, **kwargs):
672        if self.is_ctx:
673            assign_mapper(self.context, cls, *args, **kwargs)
674        else:
675            cls.mapper = self.context.mapper(cls, *args, **kwargs)
676       
677    session = property(lambda s:s.context.registry())
Note: See TracBrowser for help on using the browser.