root / elixir / trunk / elixir / entity.py @ 183

Revision 183, 19.5 kB (checked in by ged, 6 years ago)

added more statement hooks: before_table, before_mapper, after_mapper,
finalize

Line 
1'''
2Entity baseclass, metaclass and descriptor
3'''
4
5from sqlalchemy                     import Table, Integer, String, desc,\
6                                           ForeignKey
7from sqlalchemy.orm                 import deferred, Query, MapperExtension
8from sqlalchemy.ext.assignmapper    import assign_mapper
9from sqlalchemy.util                import OrderedDict
10import sqlalchemy
11from elixir.statements              import Statement
12from elixir.fields                  import Field
13from elixir.options                 import options_defaults
14
15try:
16    set
17except NameError:
18    from sets import Set as set
19
20import sys
21import warnings
22
23import elixir
24import inspect
25
26__pudge_all__ = ['Entity', 'EntityMeta']
27
28DEFAULT_AUTO_PRIMARYKEY_NAME = "id"
29DEFAULT_AUTO_PRIMARYKEY_TYPE = Integer
30DEFAULT_VERSION_ID_COL = "row_version"
31DEFAULT_POLYMORPHIC_COL_NAME = "row_type"
32DEFAULT_POLYMORPHIC_COL_SIZE = 20
33DEFAULT_POLYMORPHIC_COL_TYPE = String(DEFAULT_POLYMORPHIC_COL_SIZE)
34
35class EntityDescriptor(object):
36    '''
37    EntityDescriptor describes fields and options needed for table creation.
38    '''
39   
40    uninitialized_rels = set()
41    current = None
42   
43    def __init__(self, entity):
44        entity.table = None
45        entity.mapper = None
46
47        self.entity = entity
48        self.module = sys.modules[entity.__module__]
49
50        self.has_pk = False
51
52        self.parent = None
53        self.children = []
54
55        for base in entity.__bases__:
56            if issubclass(base, Entity) and base is not Entity:
57                if self.parent:
58                    raise Exception('%s entity inherits from several entities,'
59                                    ' and this is not supported.' 
60                                    % self.entity.__name__)
61                else:
62                    self.parent = base
63                    self.parent._descriptor.children.append(entity)
64
65        self.fields = OrderedDict()
66        #TODO Ordered
67        self.relationships = dict()
68        self.delayed_properties = dict()
69        self.constraints = list()
70
71        # set default value for options
72        self.order_by = None
73        self.table_args = list()
74        self.metadata = getattr(self.module, 'metadata', elixir.metadata)
75
76        for option in ('inheritance', 'polymorphic',
77                       'autoload', 'tablename', 'shortnames', 
78                       'auto_primarykey',
79                       'version_id_col'):
80            setattr(self, option, options_defaults[option])
81
82        for option_dict in ('mapper_options', 'table_options'):
83            setattr(self, option_dict, options_defaults[option_dict].copy())
84   
85    def setup_options(self):
86        '''
87        Setup any values that might depend on using_options. For example, the
88        tablename or the metadata.
89        '''
90        elixir.metadatas.add(self.metadata)
91
92        entity = self.entity
93        if self.inheritance == 'concrete' and self.polymorphic:
94            raise NotImplementedError("Polymorphic concrete inheritance is "
95                                      "not yet implemented.")
96
97        if self.parent:
98            if self.inheritance == 'single':
99                self.tablename = self.parent._descriptor.tablename
100
101        if not self.tablename:
102            if self.shortnames:
103                self.tablename = entity.__name__.lower()
104            else:
105                modulename = entity.__module__.replace('.', '_')
106                tablename = "%s_%s" % (modulename, entity.__name__)
107                self.tablename = tablename.lower()
108        elif callable(self.tablename):
109            self.tablename = self.tablename(entity)
110   
111    def setup_events(self):
112        # create a list of callbacks for each event
113        methods = {}
114        for name, func in inspect.getmembers(self.entity, inspect.ismethod):
115            if hasattr(func, '_elixir_events'):
116                for event in func._elixir_events:
117                    event_methods = methods.setdefault(event, [])
118                    event_methods.append(func)
119       
120        if not methods:
121            return
122       
123        # transform that list into methods themselves
124        for event in methods:
125            methods[event] = self.make_proxy_method(methods[event])
126       
127        # create a custom mapper extension class, tailored to our entity
128        ext = type('EventMapperExtension', (MapperExtension,), methods)()
129       
130        # then, make sure that the entity's mapper has our mapper extension
131        self.add_mapper_extension(ext)
132   
133    def make_proxy_method(self, methods):
134        def proxy_method(self, mapper, connection, instance):
135            for func in methods:
136                func(instance)
137        return proxy_method
138   
139    def translate_order_by(self, order_by):
140        if isinstance(order_by, basestring):
141            order_by = [order_by]
142       
143        order = list()
144        for field in order_by:
145            col = self.fields[field.strip('-')].column
146            if field.startswith('-'):
147                col = desc(col)
148            order.append(col)
149        return order
150
151    def setup_mapper(self):
152        '''
153        Initializes and assign an (empty!) mapper to the entity.
154        '''
155        if self.entity.mapper:
156            return
157       
158        session = getattr(self.module, 'session', elixir.objectstore)
159       
160        kwargs = self.mapper_options
161        if self.order_by:
162            kwargs['order_by'] = self.translate_order_by(self.order_by)
163       
164        if self.version_id_col:
165            kwargs['version_id_col'] = self.fields[self.version_id_col].column
166
167        if self.inheritance in ('single', 'concrete', 'multi'):
168            if self.parent and \
169               not (self.inheritance == 'concrete' and not self.polymorphic):
170                kwargs['inherits'] = self.parent.mapper
171
172            if self.polymorphic:
173                if self.children and not self.parent:
174                    kwargs['polymorphic_on'] = \
175                        self.fields[self.polymorphic].column
176                    if self.inheritance == 'multi':
177                        children = self._get_children()
178                        join = self.entity.table
179                        for child in children:
180                            join = join.outerjoin(child.table)
181                        kwargs['select_table'] = join
182                   
183                if self.children or self.parent:
184                    #TODO: make this customizable (both callable and string)
185                    #TODO: include module name
186                    kwargs['polymorphic_identity'] = \
187                        self.entity.__name__.lower()
188
189                if self.inheritance == 'concrete':
190                    kwargs['concrete'] = True
191
192        properties = dict()
193        for field in self.fields.itervalues():
194            if field.deferred:
195                group = None
196                if isinstance(field.deferred, basestring):
197                    group = field.deferred
198                properties[field.column.name] = deferred(field.column,
199                                                         group=group)
200
201        for name, prop in self.delayed_properties.iteritems():
202            properties[name] = self.evaluate_property(prop)
203        self.delayed_properties.clear()
204
205        if 'primary_key' in kwargs:
206            cols = self.entity.table.c
207            kwargs['primary_key'] = [getattr(cols, colname) for
208                colname in kwargs['primary_key']]
209
210        if self.parent and self.inheritance == 'single':
211            args = []
212        else:
213            args = [self.entity.table]
214
215        assign_mapper(session.context, self.entity, properties=properties, 
216                      *args, **kwargs)
217
218    def _get_children(self):
219        children = self.children[:]
220        for child in self.children:
221            children.extend(child._descriptor._get_children())
222        return children
223
224    def evaluate_property(self, prop):
225        if callable(prop):
226            return prop(self.entity.table.c)
227        else:
228            return prop
229
230    def add_property(self, name, prop):
231        if self.entity.mapper:
232            prop_value = self.evaluate_property(prop)
233            self.entity.mapper.add_property(name, prop_value)
234        else:
235            self.delayed_properties[name] = prop
236   
237    def add_mapper_extension(self, extension):
238        extensions = self.mapper_options.get('extension', [])
239        if not isinstance(extensions, list):
240            extensions = [extensions]
241        extensions.append(extension)
242        self.mapper_options['extension'] = extensions
243   
244    def setup_table(self):
245        '''
246        Create a SQLAlchemy table-object with all columns that have been
247        defined up to this point.
248        '''
249        if self.entity.table:
250            return
251       
252        if self.parent:
253            if self.inheritance == 'single':
254                # we know the parent is setup before the child
255                self.entity.table = self.parent.table 
256
257                # re-add the entity fields to the parent entity so that they
258                # are added to the parent's table (whether the parent's table
259                # is already setup or not).
260                for field in self.fields.itervalues():
261                    self.parent._descriptor.add_field(field)
262
263                return
264            elif self.inheritance == 'concrete':
265               # copy all fields from parent table
266               for field in self.parent._descriptor.fields.itervalues():
267                    self.add_field(field.copy())
268
269        if self.polymorphic and self.inheritance in ('single', 'multi') and \
270           self.children and not self.parent:
271            if not isinstance(self.polymorphic, basestring):
272                self.polymorphic = DEFAULT_POLYMORPHIC_COL_NAME
273               
274            self.add_field(Field(DEFAULT_POLYMORPHIC_COL_TYPE, 
275                                 colname=self.polymorphic))
276
277        if self.version_id_col:
278            if not isinstance(self.version_id_col, basestring):
279                self.version_id_col = DEFAULT_VERSION_ID_COL
280            self.add_field(Field(Integer, colname=self.version_id_col))
281
282        # create list of columns and constraints
283        args = [field.column for field in self.fields.itervalues()] \
284                    + self.constraints + self.table_args
285       
286        # specify options
287        kwargs = self.table_options
288
289        if self.autoload:
290            kwargs['autoload'] = True
291       
292        self.entity.table = Table(self.tablename, self.metadata, 
293                                  *args, **kwargs)
294
295
296    def create_pk_cols(self):
297        """
298        Create primary_key columns. That is, add columns from belongs_to
299        relationships marked as being a primary_key and then adds a primary
300        key to the table if it hasn't already got one and needs one.
301       
302        This method is "semi-recursive" in that it calls the create_keys
303        method on BelongsTo relationships and those in turn call create_pk_cols
304        on their target. It shouldn't be possible to have an infinite loop
305        since a loop of primary_keys is not a valid situation.
306        """
307        for rel in self.relationships.itervalues():
308            rel.create_keys(True)
309
310        if not self.autoload:
311            if self.parent and self.inheritance == 'multi':
312                # add foreign keys to the parent's primary key columns
313                parent_desc = self.parent._descriptor
314                for pk_col in parent_desc.primary_keys:
315                    colname = "%s_%s" % (self.parent.__name__.lower(),
316                                         pk_col.name)
317                    field = Field(pk_col.type, ForeignKey(pk_col), 
318                                  colname=colname, primary_key=True)
319                    self.add_field(field)
320            if not self.has_pk and self.auto_primarykey:
321                self.create_auto_primary_key()
322
323
324    def create_auto_primary_key(self):
325        '''
326        Creates a primary key
327        '''
328       
329        if isinstance(self.auto_primarykey, basestring):
330            colname = self.auto_primarykey
331        else:
332            colname = DEFAULT_AUTO_PRIMARYKEY_NAME
333       
334        self.add_field(Field(DEFAULT_AUTO_PRIMARYKEY_TYPE,
335                             colname=colname, primary_key=True))
336       
337    def add_field(self, field):
338        self.fields[field.colname] = field
339       
340        if field.primary_key:
341            self.has_pk = True
342
343        # we don't want to trigger setup_all too early
344        table = type.__getattribute__(self.entity, 'table')
345        if table:
346            table.append_column(field.column)
347   
348    def add_constraint(self, constraint):
349        self.constraints.append(constraint)
350       
351        table = self.entity.table
352        if table:
353            table.append_constraint(constraint)
354       
355    def get_inverse_relation(self, rel, reverse=False):
356        '''
357        Return the inverse relation of rel, if any, None otherwise.
358        '''
359
360        matching_rel = None
361        for other_rel in self.relationships.itervalues():
362            if other_rel.is_inverse(rel):
363                if matching_rel is None:
364                    matching_rel = other_rel
365                else:
366                    raise Exception(
367                            "Several relations match as inverse of the '%s' "
368                            "relation in entity '%s'. You should specify "
369                            "inverse relations manually by using the inverse "
370                            "keyword."
371                            % (rel.name, rel.entity.__name__))
372        # When a matching inverse is found, we check that it has only
373        # one relation matching as its own inverse. We don't need the result
374        # of the method though. But we do need to be careful not to start an
375        # infinite recursive loop.
376        if matching_rel and not reverse:
377            rel.entity._descriptor.get_inverse_relation(matching_rel, True)
378
379        return matching_rel
380
381    def primary_keys(self):
382        if self.autoload:
383            return [col for col in self.entity.table.primary_key.columns]
384        else:
385            return [field.column for field in self.fields.itervalues() if
386                    field.primary_key]
387    primary_keys = property(primary_keys)
388
389    def all_relationships(self):
390        if self.parent:
391            res = self.parent._descriptor.all_relationships
392        else:
393            res = dict()
394        res.update(self.relationships)
395        return res
396    all_relationships = property(all_relationships)
397
398
399class TriggerProxy(object):
400    def __init__(self, class_, attrname, setupfunc):
401        self.class_ = class_
402        self.attrname = attrname
403        self.setupfunc = setupfunc
404
405    def __getattr__(self, name):
406        self.setupfunc()
407        proxied_attr = getattr(self.class_, self.attrname)
408        return getattr(proxied_attr, name)
409
410    def __repr__(self):
411        proxied_attr = getattr(self.class_, self.attrname)
412        return "<TriggerProxy (%s)>" % (self.class_.__name__)
413
414class EntityMeta(type):
415    """
416    Entity meta class.
417    You should only use this if you want to define your own base class for your
418    entities (ie you don't want to use the provided 'Entity' class).
419    """
420    _ready = False
421    _entities = {}
422
423    def __init__(cls, name, bases, dict_):
424        # only process subclasses of Entity, not Entity itself
425        if bases[0] is object:
426            return
427
428        cid = cls._caller = id(sys._getframe(1))
429        caller_entities = EntityMeta._entities.setdefault(cid, {})
430        caller_entities[name] = cls
431
432        # create the entity descriptor
433        desc = cls._descriptor = EntityDescriptor(cls)
434
435        # process statements. Needed before the proxy for metadata
436        Statement.process(cls)
437
438        # setup misc options here (like tablename etc.)
439        desc.setup_options()
440
441        # create trigger proxies
442        # TODO: support entity_name... or maybe not. I'm not sure it makes
443        # sense in Elixir.
444        cls.setup_proxy()
445
446    def setup_proxy(cls, entity_name=None):
447        #TODO: move as much as possible of those "_private" values to the
448        # descriptor, so that we don't mess the initial class.
449        cls._class_key = sqlalchemy.orm.mapperlib.ClassKey(cls, entity_name)
450
451        tablename = cls._descriptor.tablename
452        schema = cls._descriptor.table_options.get('schema', None)
453        cls._table_key = sqlalchemy.schema._get_table_key(tablename, schema)
454
455        elixir._delayed_descriptors.append(cls._descriptor)
456       
457        mapper_proxy = TriggerProxy(cls, 'mapper', elixir.setup_all)
458        table_proxy = TriggerProxy(cls, 'table', elixir.setup_all)
459
460        sqlalchemy.orm.mapper_registry[cls._class_key] = mapper_proxy
461        md = cls._descriptor.metadata
462        md.tables[cls._table_key] = table_proxy
463
464        # We need to monkeypatch the metadata's table iterator method because
465        # otherwise it doesn't work if the setup is triggered by the
466        # metadata.create_all().
467        # This is because ManyToMany relationships add tables AFTER the list
468        # of tables that are going to be created is "computed"
469        # (metadata.tables.values()).
470        # see:
471        # - table_iterator method in MetaData class in sqlalchemy/schema.py
472        # - visit_metadata method in sqlalchemy/ansisql.py
473        original_table_iterator = md.table_iterator
474        if not hasattr(original_table_iterator, 
475                       '_non_elixir_patched_iterator'):
476            def table_iterator(*args, **kwargs):
477                elixir.setup_all()
478                return original_table_iterator(*args, **kwargs)
479            table_iterator.__doc__ = original_table_iterator.__doc__
480            table_iterator._non_elixir_patched_iterator = \
481                original_table_iterator
482            md.table_iterator = table_iterator
483
484        cls._ready = True
485
486    def __getattribute__(cls, name):
487        if type.__getattribute__(cls, "_ready"):
488            #TODO: we need to add all assign_mapper methods
489            if name in ('c', 'table', 'mapper'):
490                elixir.setup_all()
491        return type.__getattribute__(cls, name)
492
493    def __call__(cls, *args, **kwargs):
494        elixir.setup_all()
495        return type.__call__(cls, *args, **kwargs)
496
497    def q(cls):
498        return Query(cls, session=elixir.objectstore.session)
499    q = property(q)
500
501
502class Entity(object):
503    '''
504    The base class for all entities
505   
506    All Elixir model objects should inherit from this class. Statements can
507    appear within the body of the definition of an entity to define its
508    fields, relationships, and other options.
509   
510    Here is an example:
511
512    ::
513   
514        class Person(Entity):
515            has_field('name', Unicode(128))
516            has_field('birthdate', DateTime, default=datetime.now)
517   
518    Please note, that if you don't specify any primary keys, Elixir will
519    automatically create one called ``id``.
520   
521    For further information, please refer to the provided examples or
522    tutorial.
523    '''
524
525    __metaclass__ = EntityMeta
526
527    def __init__(self, **kwargs):
528        for key, value in kwargs.items():
529            setattr(self, key, value)
530
531    def get_by(cls, *args, **kwargs):
532#        warnings.warn("The get_by method on the class is deprecated."
533#                      "You should use cls.query.get_by", DeprecationWarning,
534#                      stacklevel=2)
535        return cls.q.get_by(*args, **kwargs)
536    get_by = classmethod(get_by)
537
538    def select(cls, *args, **kwargs):
539#        warnings.warn("The select method on the class is deprecated."
540#                      "You should use cls.query.select", DeprecationWarning,
541#                      stacklevel=2)
542        return cls.q.select(*args, **kwargs)
543    select = classmethod(select)
544
Note: See TracBrowser for help on using the browser.