Changeset 360

Show
Ignore:
Timestamp:
07/11/08 15:58:29 (6 years ago)
Author:
ged
Message:
  • cleanup "copy properties from base class" implementation (closes #15)
  • rewrite from_dict to use property names instead of column names (fixes #47,
    test from Jason R. Coombs)
Location:
elixir/trunk
Files:
2 modified

Legend:

Unmodified
Added
Removed
  • elixir/trunk/elixir/entity.py

    r356 r360  
    1515                           ForeignKeyConstraint 
    1616from sqlalchemy.orm import Query, MapperExtension, mapper, object_session, \ 
    17                            EXT_CONTINUE, polymorphic_union, ScopedSession 
     17                           EXT_CONTINUE, polymorphic_union, ScopedSession, \ 
     18                           ColumnProperty 
    1819 
    1920import elixir 
     
    6263        self.constraints = list() 
    6364 
    64         # properties waiting for a mapper to exist 
     65        # properties (it is only useful for checking dupe properties at the 
     66        # moment, and when adding properties before the mapper is created, 
     67        # which shouldn't happen). 
    6568        self.properties = dict() 
    6669 
     
    326329    def setup_mapper(self): 
    327330        ''' 
    328         Initializes and assign an (empty!) mapper to the entity. 
     331        Initializes and assign a mapper to the entity. 
     332        At this point the mapper will usually have no property as they are 
     333        added later. 
    329334        ''' 
    330335        if self.entity.mapper: 
     
    400405 
    401406        # do the mapping 
    402         kwargs['properties'] = self.properties 
    403407        if self.session is None: 
    404408            self.entity.mapper = mapper(self.entity, *args, **kwargs) 
     
    491495 
    492496    def get_column(self, key, check_missing=True): 
    493         "need to support both the case where the table is already setup or not" 
     497        #TODO: this needs to work whether the table is already setup or not 
    494498        #TODO: support SA table/autoloaded entity 
    495499        for col in self.columns: 
     
    536540            return None 
    537541 
     542    #------------------------ 
     543    # some useful properties 
     544 
    538545    def table_fullname(self): 
    539546        ''' 
     
    576583    primary_keys = property(primary_keys) 
    577584 
     585    def primary_key_properties(self): 
     586        """ 
     587        Returns the list of (mapper) properties corresponding to the primary 
     588        key columns of the table of the entity. 
     589 
     590        This property caches its value, so it shouldn't be called before the 
     591        entity is fully set up. 
     592        """ 
     593        if not hasattr(self, '_pk_props'): 
     594            col_to_prop = {} 
     595            mapper = self.entity.mapper 
     596            for prop in mapper.iterate_properties: 
     597                if isinstance(prop, ColumnProperty): 
     598                    for col in prop.columns: 
     599                        for col in col.proxy_set: 
     600                            col_to_prop[col] = prop 
     601            pk_cols = [c for c in mapper.mapped_table.c if c.primary_key] 
     602            self._pk_props = [col_to_prop[c] for c in pk_cols] 
     603        return self._pk_props 
     604    primary_key_properties = property(primary_key_properties) 
    578605 
    579606class TriggerProxy(object): 
     
    644671        desc = cls._descriptor = EntityDescriptor(cls) 
    645672 
    646         # Process attributes (using the assignment syntax), looking for 
    647         # 'Property' instances and attaching them to this entity. 
    648         properties = [(name, attr) for name, attr in dict_.iteritems() 
    649                                    if isinstance(attr, Property)] 
    650         sorted_props = sorted(properties, key=lambda i: i[1]._counter) 
    651         for name, prop in sorted_props: 
    652             prop.attach(cls, name) 
    653  
     673        # Determine whether this entity is a *direct* subclass of its base 
     674        # entity 
    654675        entity_base = None 
    655         for base in bases: 
     676        for base in cls.__bases__: 
    656677            if isinstance(base, EntityMeta): 
    657678                if not is_entity(base): 
    658679                    entity_base = base 
     680 
    659681        if entity_base: 
    660             # Process attributes (using the assignment syntax), looking for 
    661             # 'Property' instances and attaching them to this entity. 
     682            # If so, copy the base entity properties ('Property' instances). 
     683            # We use inspect.getmembers (instead of __dict__) so that we also 
     684            # get the properties from the parents of the base_class if any. 
    662685            base_props = inspect.getmembers(entity_base, 
    663686                                            lambda a: isinstance(a, Property)) 
    664             local_props = [(name, copy(attr)) for name, attr in base_props] 
    665             sorted_props = sorted(local_props, key=lambda i: i[1]._counter) 
    666             for name, prop in sorted_props: 
    667                 prop.attach(cls, name) 
     687            base_props = [(name, copy(attr)) for name, attr in base_props] 
     688        else: 
     689            base_props = [] 
     690 
     691        # Process attributes (using the assignment syntax), looking for 
     692        # 'Property' instances and attaching them to this entity. 
     693        properties = [(name, attr) for name, attr in cls.__dict__.iteritems() 
     694                                   if isinstance(attr, Property)] 
     695        sorted_props = sorted(base_props + properties, 
     696                              key=lambda i: i[1]._counter) 
     697        for name, prop in sorted_props: 
     698            prop.attach(cls, name) 
    668699 
    669700        # Process mutators. Needed before _install_autosetup_triggers so that 
     
    730761 
    731762    #TODO: we might want to add all columns that will be available as 
    732     #attributes on the class itself (in SA 0.4). This would be a pretty 
     763    #attributes on the class itself (in SA 0.4+). This is a pretty 
    733764    #rare usecase, as people will normally hit the query attribute before the 
    734     #column attributes, but still... 
     765    #column attributes, but I've seen people hitting this problem... 
    735766    for name in ('c', 'table', 'mapper', 'query'): 
    736767        setattr(cls, name, TriggerAttribute(name)) 
     
    842873 
    843874    def __init__(self, **kwargs): 
    844         for key, value in kwargs.items(): 
    845             setattr(self, key, value) 
     875        self.from_dict(kwargs) 
    846876 
    847877    def set(self, **kwargs): 
    848878        self.from_dict(kwargs) 
     879 
     880    def update_or_create(cls, data, surrogate=True): 
     881        pk_props = cls._descriptor.primary_key_properties 
     882 
     883        # if all pk are present and not None 
     884        if not [1 for p in pk_props if data.get(p.key) is None]: 
     885            pk_tuple = tuple([data[prop.key] for prop in pk_props]) 
     886            record = cls.query.get(pk_tuple) 
     887            if record is None: 
     888                if surrogate: 
     889                    raise Exception("cannot create surrogate with pk") 
     890                else: 
     891                    record = cls() 
     892        else: 
     893            if surrogate: 
     894                record = cls() 
     895            else: 
     896                raise Exception("cannot create non surrogate without pk") 
     897        record.from_dict(data) 
     898        return record 
     899    update_or_create = classmethod(update_or_create) 
    849900 
    850901    def from_dict(self, data): 
     
    853904        structure. 
    854905        """ 
     906        # surrogate can be guessed from autoincrement/sequence but I guess 
     907        # that's not 100% reliable, so we'll need an override 
     908 
    855909        mapper = sqlalchemy.orm.object_mapper(self) 
    856         session = sqlalchemy.orm.object_session(self) 
    857  
    858         for col in mapper.mapped_table.c: 
    859             if not col.primary_key and data.has_key(col.name): 
    860                 setattr(self, col.name, data[col.name]) 
    861  
    862         for rel in mapper.iterate_properties: 
    863             rname = rel.key 
    864             if isinstance(rel, sqlalchemy.orm.properties.PropertyLoader) \ 
    865                     and data.has_key(rname): 
    866                 dbdata = getattr(self, rname) 
    867                 if rel.uselist: 
    868                     pkey = [c for c in rel.table.columns if c.primary_key] 
    869  
    870                     # Build a lookup dict: {(pk1, pk2): value} 
    871                     lookup = dict([ 
    872                         (tuple([getattr(o, c.name) for c in pkey]), o) 
    873                         for o in dbdata]) 
    874                     for row in data[rname]: 
    875                         # If any primary key columns are missing or None, 
    876                         # create a new object 
    877                         if [1 for c in pkey if not row.get(c.name)]: 
    878                             subobj = rel.mapper.class_() 
    879                             dbdata.append(subobj) 
    880                         else: 
    881                             key = tuple([row[c.name] for c in pkey]) 
    882                             subobj = lookup.pop(key, None) 
    883  
    884                             # If the row isn't found, we must fail the request 
    885                             # in a web scenario, this could be a parameter 
    886                             # tampering attack 
    887                             if not subobj: 
    888                                 raise sqlalchemy.exceptions.ArgumentError( 
    889                                         '%s row not found in database: %s' \ 
    890                                         % (rname, repr(row))) 
    891                         subobj.from_dict(row) 
    892  
    893                     # Make sure the object list attribute doesn't contain any 
    894                     # old value (which are not present in the new data). 
    895                     for delobj in lookup.itervalues(): 
    896                         dbdata.remove(delobj) 
    897                         session.delete(delobj) 
     910 
     911        for key, value in data.iteritems(): 
     912            if isinstance(value, dict): 
     913                dbvalue = getattr(self, key) 
     914                rel_class = mapper.get_property(key).mapper.class_ 
     915                pk_props = rel_class._descriptor.primary_key_properties 
     916 
     917                # If the data doesn't contain any pk, and the relationship 
     918                # already has a value, update that record. 
     919                if not [1 for p in pk_props if p.key in data] and \ 
     920                   dbvalue is not None: 
     921                    dbvalue.from_dict(value) 
    898922                else: 
    899                     if data[rname] is None: 
    900                         setattr(self, rname, None) 
    901                     else: 
    902                         if not dbdata: 
    903                             dbdata = rel.mapper.class_() 
    904                             setattr(self, rname, dbdata) 
    905                         dbdata.from_dict(data[rname]) 
     923                    record = rel_class.update_or_create(value) 
     924                    setattr(self, key, record) 
     925            elif isinstance(value, list) and \ 
     926                 value and isinstance(value[0], dict): 
     927 
     928                rel_class = mapper.get_property(key).mapper.class_ 
     929                new_attr_value = [] 
     930                for row in value: 
     931                    if not isinstance(row, dict): 
     932                        raise Exception( 
     933                                'Cannot send mixed (dict/non dict) data ' 
     934                                'to list relationships in from_dict data.') 
     935                    record = rel_class.update_or_create(row) 
     936                    new_attr_value.append(record) 
     937                setattr(self, key, new_attr_value) 
     938            else: 
     939                setattr(self, key, value) 
    906940 
    907941    def to_dict(self, deep={}, exclude=[]): 
    908942        """Generate a JSON-style nested dict/list structure from an object.""" 
    909         columns = [] 
    910         for table in self.mapper.tables: 
    911             for col in table.c: 
    912                 columns.append(col) 
    913  
    914         data = dict([(col.name, getattr(self, col.name)) 
    915                      for col in columns if col.name not in exclude]) 
     943        col_prop_names = [p.key for p in self.mapper.iterate_properties \ 
     944                                      if isinstance(p, ColumnProperty)] 
     945        data = dict([(name, getattr(self, name)) 
     946                     for name in col_prop_names if name not in exclude]) 
    916947        for rname, rdeep in deep.iteritems(): 
    917948            dbdata = getattr(self, rname) 
     949            #FIXME: use attribute names (ie coltoprop) instead of column names 
    918950            fks = self.mapper.get_property(rname).remote_side 
    919951            exclude = [c.name for c in fks] 
  • elixir/trunk/tests/test_dict.py

    r349 r360  
    33""" 
    44 
    5 import sqlalchemy as sa, elixir as el 
     5import sqlalchemy as sa 
     6from elixir import * 
    67 
    78def setup(): 
    8     el.metadata.bind = 'sqlite:///' 
     9    metadata.bind = 'sqlite:///' 
     10 
    911    global Table1, Table2, Table3 
    10     class Table1(el.Entity): 
    11         name = el.Field(el.String(30)) 
    12         tbl2s = el.OneToMany('Table2') 
    13         tbl3 = el.OneToOne('Table3') 
    14     class Table2(el.Entity): 
    15         name = el.Field(el.String(30)) 
    16         tbl1 = el.ManyToOne(Table1) 
    17     class Table3(el.Entity): 
    18         name = el.Field(el.String(30)) 
    19         tbl1 = el.ManyToOne(Table1) 
    20     el.setup_all() 
    21     el.create_all() 
     12    class Table1(Entity): 
     13        t1id = Field(Integer, primary_key=True) 
     14        name = Field(String(30)) 
     15        tbl2s = OneToMany('Table2') 
     16        tbl3 = OneToOne('Table3') 
    2217 
    23 def test_set_attr(): 
    24     t1 = Table1() 
    25     t1.from_dict(dict(name='test1')) 
    26     assert t1.name == 'test1' 
     18    class Table2(Entity): 
     19        t2id = Field(Integer, primary_key=True) 
     20        name = Field(String(30)) 
     21        tbl1 = ManyToOne(Table1) 
    2722 
    28 def test_nonset_attr(): 
    29     t1 = Table1(name='test2') 
    30     t1.from_dict({}) 
    31     assert t1.name == 'test2' 
     23    class Table3(Entity): 
     24        t3id = Field(Integer, primary_key=True) 
     25        name = Field(String(30)) 
     26        tbl1 = ManyToOne(Table1) 
    3227 
    33 def test_set_rel(): 
    34     t1 = Table1() 
    35     t1.from_dict(dict(tbl3={'name':'bob'})) 
    36     assert t1.tbl3.name == 'bob' 
     28    setup_all(True) 
    3729 
    38 def test_remove_rel(): 
    39     t1 = Table1() 
    40     t1.tbl3 = Table3() 
    41     t1.from_dict(dict(tbl3=None)) 
    42     assert t1.tbl3 is None 
     30def teardown(): 
     31    cleanup_all(True) 
    4332 
    44 def test_update_rel(): 
    45     t1 = Table1() 
    46     t1.tbl3 = Table3(name='fred') 
    47     t1.from_dict(dict(tbl3={'name':'bob'})) 
    48     assert t1.tbl3.name == 'bob' 
     33class TestDeepSet(object): 
     34    def test_set_attr(self): 
     35        t1 = Table1() 
     36        t1.from_dict(dict(name='test1')) 
     37        assert t1.name == 'test1' 
    4938 
    50 def test_extend_list(): 
    51     t1 = Table1() 
    52     t1.from_dict(dict(tbl2s=[{'name':'test3'}])) 
    53     assert len(t1.tbl2s) == 1 
    54     assert t1.tbl2s[0].name == 'test3' 
     39    def test_nonset_attr(self): 
     40        t1 = Table1(name='test2') 
     41        t1.from_dict({}) 
     42        assert t1.name == 'test2' 
    5543 
    56 def test_truncate_list(): 
    57     t1 = Table1() 
    58     t2 = Table2() 
    59     t1.tbl2s.append(t2) 
    60     el.session.commit() 
    61     t1.from_dict(dict(tbl2s=[])) 
    62     assert len(t1.tbl2s) == 0 
     44    def test_set_rel(self): 
     45        t1 = Table1() 
     46        t1.from_dict(dict(tbl3={'name': 'bob'})) 
     47        assert t1.tbl3.name == 'bob' 
    6348 
    64 def test_update_list_item(): 
    65     t1 = Table1() 
    66     t2 = Table2() 
    67     t1.tbl2s.append(t2) 
    68     el.session.commit() 
    69     t1.from_dict(dict(tbl2s=[{'id':t2.id, 'name':'test4'}])) 
    70     assert len(t1.tbl2s) == 1 
    71     assert t1.tbl2s[0].name == 'test4' 
     49    def test_remove_rel(self): 
     50        t1 = Table1() 
     51        t1.tbl3 = Table3() 
     52        t1.from_dict(dict(tbl3=None)) 
     53        assert t1.tbl3 is None 
    7254 
    73 def test_invalid_update(): 
    74     t1 = Table1() 
    75     t2 = Table2() 
    76     t1.tbl2s.append(t2) 
    77     el.session.commit() 
    78     try: 
    79         t1.from_dict(dict(tbl2s=[{'id':t2.id+1}])) 
    80         assert False 
    81     except sa.exceptions.ArgumentError: 
    82         pass 
     55    def test_update_rel(self): 
     56        t1 = Table1() 
     57        t1.tbl3 = Table3(name='fred') 
     58        t1.from_dict(dict(tbl3={'name': 'bob'})) 
     59        assert t1.tbl3.name == 'bob' 
    8360 
    84 def test_to(): 
    85     t1 = Table1(id=50, name='test1') 
    86     assert t1.to_dict() == {'id':50, 'name':'test1'} 
     61    def test_extend_list(self): 
     62        t1 = Table1() 
     63        t1.from_dict(dict(tbl2s=[{'name': 'test3'}])) 
     64        assert len(t1.tbl2s) == 1 
     65        assert t1.tbl2s[0].name == 'test3' 
    8766 
    88 def test_to_deep(): 
    89     t1 = Table1(id=51, name='test2') 
    90     assert t1.to_dict(deep={'tbl2s':{}}) == \ 
    91             {'id':51, 'name':'test2', 'tbl2s':[]} 
     67    def test_truncate_list(self): 
     68        t1 = Table1() 
     69        t2 = Table2() 
     70        t1.tbl2s.append(t2) 
     71        session.commit() 
     72        t1.from_dict(dict(tbl2s=[])) 
     73        assert len(t1.tbl2s) == 0 
    9274 
    93 def test_to_deep2(): 
    94     t1 = Table1(id=52, name='test3') 
    95     t2 = Table2(id=50, name='test4') 
    96     t1.tbl2s.append(t2) 
    97     el.session.commit() 
    98     assert t1.to_dict(deep={'tbl2s':{}}) == \ 
    99             {'id':52, 'name':'test3', 'tbl2s':[{'id':50, 'name':'test4'}]} 
     75    def test_update_list_item(self): 
     76        t1 = Table1() 
     77        t2 = Table2() 
     78        t1.tbl2s.append(t2) 
     79        session.commit() 
     80        t1.from_dict(dict(tbl2s=[{'t2id': t2.t2id, 'name': 'test4'}])) 
     81        assert len(t1.tbl2s) == 1 
     82        assert t1.tbl2s[0].name == 'test4' 
    10083 
    101 def test_to_deep3(): 
    102     t1 = Table1(id=53, name='test2') 
    103     t1.tbl3 = Table3(id=50, name='wobble') 
    104     el.session.commit() 
    105     assert t1.to_dict(deep={'tbl3':{}}) == \ 
    106             {'id':53, 'name':'test2', 'tbl3':{'id':50,'name':'wobble'}} 
     84    def test_invalid_update(self): 
     85        t1 = Table1() 
     86        t2 = Table2() 
     87        t1.tbl2s.append(t2) 
     88        session.commit() 
     89        try: 
     90            t1.from_dict(dict(tbl2s=[{'t2id':t2.t2id+1}])) 
     91            assert False 
     92        except: 
     93            pass 
     94 
     95    def test_to(self): 
     96        t1 = Table1(t1id=50, name='test1') 
     97        assert t1.to_dict() == {'t1id': 50, 'name': 'test1'} 
     98 
     99    def test_to_deep(self): 
     100        t1 = Table1(t1id=51, name='test2') 
     101        assert t1.to_dict(deep={'tbl2s':{}}) == \ 
     102                {'t1id': 51, 'name': 'test2', 'tbl2s': []} 
     103 
     104    def test_to_deep2(self): 
     105        t1 = Table1(t1id=52, name='test3') 
     106        t2 = Table2(t2id=50, name='test4') 
     107        t1.tbl2s.append(t2) 
     108        session.commit() 
     109        assert t1.to_dict(deep={'tbl2s':{}}) == \ 
     110                {'t1id': 52, 
     111                 'name': 'test3', 
     112                 'tbl2s': [{'t2id': 50, 'name': 'test4'}]} 
     113 
     114    def test_to_deep3(self): 
     115        t1 = Table1(t1id=53, name='test2') 
     116        t1.tbl3 = Table3(t3id=50, name='wobble') 
     117        session.commit() 
     118        assert t1.to_dict(deep={'tbl3':{}}) == \ 
     119                {'t1id': 53, 
     120                 'name': 'test2', 
     121                 'tbl3': {'t3id': 50, 'name': 'wobble'}} 
     122 
     123class TestSetOnAliasedColumn(object): 
     124    def setup(self): 
     125        metadata.bind = 'sqlite:///' 
     126 
     127    def teardown(self): 
     128        cleanup_all(True) 
     129 
     130    def test_set_on_aliased_column(self): 
     131        class A(Entity): 
     132            name = Field(String(60), colname='strName') 
     133 
     134        setup_all(True) 
     135 
     136        a = A() 
     137        a.set(name='Aye') 
     138 
     139        assert a.name == 'Aye' 
     140        session.flush() 
     141        session.clear() 
     142