[Date Prev][Date Next] [Thread Prev][Thread Next] [Date Index] [Thread Index]

[dak/master] Implement ORMObject.session() and .clone().



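- session() is syntactic sugar for sqlalchemy.orm.object_session(self)
- clone() reloads a persistent object in another session; this is needed for
  safe multithreading because SQLAlchemy sessions are not thread safe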

Signed-off-by: Torsten Werner <twerner@debian.org>
---
 daklib/dbconn.py        |   38 +++++++++++++++++++++++++++++++++++++-
 tests/dbtest_session.py |   37 ++++++++++++++++++++++++++++++++++---
 2 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/daklib/dbconn.py b/daklib/dbconn.py
index aa71c18..df38b77 100755
--- a/daklib/dbconn.py
+++ b/daklib/dbconn.py
@@ -55,7 +55,7 @@ from inspect import getargspec
 import sqlalchemy
 from sqlalchemy import create_engine, Table, MetaData, Column, Integer, desc
 from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \
-    backref, MapperExtension, EXT_CONTINUE
+    backref, MapperExtension, EXT_CONTINUE, object_mapper
 from sqlalchemy import types as sqltypes
 
 # Don't remove this, we re-export the exceptions to scripts which import us
@@ -287,6 +287,42 @@ class ORMObject(object):
         '''
         return session.query(cls).get(primary_key)
 
+    def session(self, replace = False):
+        '''
+        Returns the session the object is currently associated with, or None
+        if the object is detached. NOTE(review): 'replace' is currently unused.
+        '''
+
+        return object_session(self)
+
+    def clone(self, session = None):
+        '''
+        Clones the current object in a new session and returns the new clone. A
+        fresh session is created if the optional session parameter is not
+        provided.
+
+        RATIONALE: SQLAlchemy sessions are not thread safe; cloning lets several
+        threads each work with their own instance of an ORMObject.
+
+        WARNING: Only persistent (committed) objects can be cloned; RuntimeError
+        is raised for detached objects and for objects not (yet) committed.
+        '''
+
+        if self.session() is None:
+            raise RuntimeError('Method clone() failed for detached object:\n%s' %
+                self)
+        self.session().flush()
+        primary_key = object_mapper(self).primary_key_from_instance(self)
+        # Open a fresh session only now; it is closed again if cloning fails.
+        new_session = session if session is not None else DBConn().session()
+        new_object = new_session.query(self.__class__).get(primary_key)
+        if new_object is None:
+            if session is None:
+                new_session.close()
+            raise RuntimeError(
+                'Method clone() failed for non-persistent object:\n%s' % self)
+        return new_object
+
 __all__.append('ORMObject')
 
 ################################################################################
diff --git a/tests/dbtest_session.py b/tests/dbtest_session.py
index 7c378ce..72c2aff 100755
--- a/tests/dbtest_session.py
+++ b/tests/dbtest_session.py
@@ -2,9 +2,8 @@
 
 from db_test import DBDakTestCase
 
-from daklib.dbconn import Uid
+from daklib.dbconn import DBConn, Uid
 
-from sqlalchemy.orm import object_session
 from sqlalchemy.exc import InvalidRequestError
 
 import time
@@ -93,7 +92,6 @@ class SessionTestCase(DBDakTestCase):
         uid = Uid(uid = 'foobar')
         self.session.add(uid)
         self.assertTrue(uid in self.session)
-        self.assertEqual(self.session, object_session(uid))
         self.session.expunge(uid)
         self.assertTrue(uid not in self.session)
         # test close()
@@ -138,6 +136,39 @@ class SessionTestCase(DBDakTestCase):
         self.session.rollback()
         self.assertRaises(InvalidRequestError, self.refresh)
 
+    def test_session(self):
+        '''
+        Tests that ORMObject.session() returns the session holding the object.
+        '''
+
+        uid = Uid(uid = 'foobar')
+        self.session.add(uid)
+        self.assertEqual(self.session, uid.session())
+
+    def test_clone(self):
+        '''
+        Tests ORMObject.clone() for detached, uncommitted and committed objects.
+        '''
+
+        uid1 = Uid(uid = 'foobar')
+        # transient object without a session: clone() must raise
+        self.assertRaises(RuntimeError, uid1.clone)
+        self.session.add(uid1)
+        # added but not committed: a fresh session cannot load it yet
+        self.assertRaises(RuntimeError, uid1.clone)
+        self.session.commit()
+        # committed: clone() without a session argument uses a fresh session
+        uid2 = uid1.clone()
+        self.assertTrue(uid1 is not uid2)
+        self.assertEqual(uid1.uid, uid2.uid)
+        self.assertTrue(uid2 not in uid1.session())
+        self.assertTrue(uid1 not in uid2.session())
+        # explicit session argument (NOTE(review): new_session is never closed)
+        new_session = DBConn().session()
+        uid3 = uid1.clone(session = new_session)
+        self.assertEqual(uid1.uid, uid3.uid)
+        self.assertTrue(uid3 in new_session)
+
     def classes_to_clean(self):
         # We need to clean all Uid objects in case some test fails.
         return (Uid,)
-- 
1.5.6.5



Reply to: