Show
Ignore:
Timestamp:
10/02/08 14:12:30 (4 years ago)
Author:
ged
Message:

- Added new column_names argument to the acts_as_versioned extension, allowing

to specify custom column names (inspired by a patch by Alex Bodnaru).

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • elixir/trunk/elixir/ext/versioned.py

    r406 r409  
    8888class VersionedMapperExtension(MapperExtension): 
    8989    def before_insert(self, mapper, connection, instance): 
    90         instance.version = 1 
    91         instance.timestamp = datetime.now() 
     90        version_colname, timestamp_colname = \ 
     91            instance.__class__.__versioned_column_names__ 
     92        setattr(instance, version_colname, 1) 
     93        setattr(instance, timestamp_colname, datetime.now()) 
    9294        return EXT_CONTINUE 
    9395 
     
    102104        # data. 
    103105        ignored = instance.__class__.__ignored_fields__ 
     106        version_colname, timestamp_colname = \ 
     107            instance.__class__.__versioned_column_names__ 
    104108        for key in instance.table.c.keys(): 
    105109            if key in ignored: 
     
    110114                connection.execute( 
    111115                    instance.__class__.__history_table__.insert(), dict_values) 
    112                 instance.version = instance.version + 1 
    113                 instance.timestamp = datetime.now() 
     116                old_version = getattr(instance, version_colname) 
     117                setattr(instance, version_colname, old_version + 1) 
     118                setattr(instance, timestamp_colname, datetime.now()) 
    114119                break 
    115120 
     
    132137class VersionedEntityBuilder(EntityBuilder): 
    133138 
    134     def __init__(self, entity, ignore=[], check_concurrent=False): 
     139    def __init__(self, entity, ignore=None, check_concurrent=False, 
     140                 column_names=None): 
    135141        self.entity = entity 
    136142        self.add_mapper_extension(versioned_mapper_extension) 
     
    140146 
    141147        # Changes in these fields will be ignored 
    142         ignore.extend(['version', 'timestamp']) 
     148        if column_names is None: 
     149            column_names = ['version', 'timestamp'] 
     150        entity.__versioned_column_names__ = column_names 
     151        if ignore is None: 
     152            ignore = [] 
     153        ignore.extend(column_names) 
    143154        entity.__ignored_fields__ = ignore 
    144155 
    145156    def create_non_pk_cols(self): 
    146157        # add a version column to the entity, along with a timestamp 
    147         self.add_table_column(Column('version', Integer)) 
    148         self.add_table_column(Column('timestamp', DateTime)) 
     158        version_colname, timestamp_colname = \ 
     159            self.entity.__versioned_column_names__ 
     160        #XXX: fail in case the columns already exist? 
     161        #col_names = [col.name for col in self.entity._descriptor.columns] 
     162        #if version_colname not in col_names: 
     163        self.add_table_column(Column(version_colname, Integer)) 
     164        #if timestamp_colname not in col_names: 
     165        self.add_table_column(Column(timestamp_colname, DateTime)) 
    149166 
    150167        # add a concurrent_version column to the entity, if required 
     
    155172    def after_table(self): 
    156173        entity = self.entity 
     174        version_colname, timestamp_colname = \ 
     175            entity.__versioned_column_names__ 
    157176 
    158177        # look for events 
     
    163182 
    164183        # create a history table for the entity 
    165         #TODO: fail more noticeably in case there is a version col 
     184        skipped_columns = [version_colname] 
     185        if self.check_concurrent: 
     186            skipped_columns.append('concurrent_version') 
     187 
    166188        columns = [ 
    167189            column.copy() for column in entity.table.c 
    168             if column.name not in ('version', 'concurrent_version') 
     190            if column.name not in skipped_columns 
    169191        ] 
    170         columns.append(Column('version', Integer, primary_key=True)) 
     192        columns.append(Column(version_colname, Integer, primary_key=True)) 
    171193        table = Table(entity.table.name + '_history', entity.table.metadata, 
    172194            *columns 
     
    183205        mapper(Version, entity.__history_table__) 
    184206 
     207        version_col = getattr(table.c, version_colname) 
     208        timestamp_col = getattr(table.c, timestamp_colname) 
     209 
    185210        # attach utility methods and properties to the entity 
    186211        def get_versions(self): 
    187212            v = object_session(self).query(Version) \ 
    188213                                    .filter(get_history_where(self)) \ 
    189                                     .order_by(Version.version) \ 
     214                                    .order_by(version_col) \ 
    190215                                    .all() 
    191216            # history contains all the previous records. 
     
    197222            # if the passed in timestamp is older than our current version's 
    198223            # time stamp, then the most recent version is our current version 
    199             if self.timestamp < dt: 
     224            if getattr(self, timestamp_colname) < dt: 
    200225                return self 
    201226 
     
    205230            query = sess.query(Version) \ 
    206231                        .filter(and_(get_history_where(self), 
    207                                      Version.timestamp <= dt)) \ 
    208                         .order_by(desc(Version.timestamp)).limit(1) 
     232                                     timestamp_col <= dt)) \ 
     233                        .order_by(desc(timestamp_col)).limit(1) 
    209234            return query.first() 
    210235 
    211236        def revert_to(self, to_version): 
    212237            if isinstance(to_version, Version): 
    213                 to_version = to_version.version 
    214  
    215             hist = entity.__history_table__ 
    216             old_version = hist.select(and_( 
     238                to_version = getattr(to_version, version_colname) 
     239 
     240            old_version = table.select(and_( 
    217241                get_history_where(self), 
    218                 hist.c.version == to_version 
     242                version_col == to_version 
    219243            )).execute().fetchone() 
    220244 
     
    223247            ) 
    224248 
    225             hist.delete(and_(get_history_where(self), 
    226                              hist.c.version >= to_version)).execute() 
     249            table.delete(and_(get_history_where(self), 
     250                              version_col >= to_version)).execute() 
    227251            self.expire() 
    228252            for event in after_revert_events: 
     
    230254 
    231255        def revert(self): 
    232             assert self.version > 1 
    233             self.revert_to(self.version - 1) 
     256            assert getattr(self, version_colname) > 1 
     257            self.revert_to(getattr(self, version_colname) - 1) 
    234258 
    235259        def compare_with(self, version): 
    236260            differences = {} 
    237261            for column in self.table.c: 
    238                 if column.name in ('version', 'concurrent_version'): 
     262                if column.name in (version_colname, 'concurrent_version'): 
    239263                    continue 
    240264                this = getattr(self, column.name)