[dak/master] Add method reset() to class DBConn() and test it.
Signed-off-by: Torsten Werner <twerner@debian.org>
---
daklib/dbconn.py | 11 ++++++++++-
tests/dbtest_multiproc.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 55 insertions(+), 1 deletion(-)
create mode 100755 tests/dbtest_multiproc.py
diff --git a/daklib/dbconn.py b/daklib/dbconn.py
index 1bf5a0f..6cd84de 100755
--- a/daklib/dbconn.py
+++ b/daklib/dbconn.py
@@ -59,7 +59,7 @@ import sqlalchemy
from sqlalchemy import create_engine, Table, MetaData, Column, Integer, desc, \
Text, ForeignKey
from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \
- backref, MapperExtension, EXT_CONTINUE, object_mapper
+ backref, MapperExtension, EXT_CONTINUE, object_mapper, clear_mappers
from sqlalchemy import types as sqltypes
# Don't remove this, we re-export the exceptions to scripts which import us
@@ -3198,6 +3198,15 @@ class DBConn(object):
def session(self):
return self.db_smaker()
+ def reset(self):
+ '''
+ Resets the DBConn object by clearing all mappers and running the
+ connection setup again. Subprocesses created by the multiprocessing module
+ must call this first; see tests/dbtest_multiproc.py for an example.
+ '''
+ clear_mappers()
+ self.__createconn()
+
__all__.append('DBConn')
diff --git a/tests/dbtest_multiproc.py b/tests/dbtest_multiproc.py
new file mode 100755
index 0000000..f4c2a37
--- /dev/null
+++ b/tests/dbtest_multiproc.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+
+from db_test import DBDakTestCase
+
+from daklib.dbconn import DBConn
+
+from multiprocessing import Pool
+from time import sleep
+import unittest
+
+def read_number():
+ DBConn().reset()
+ session = DBConn().session()
+ result = session.query('foo').from_statement('select 7 as foo').scalar()
+ sleep(0.1)
+ session.close()
+ return result
+
+class MultiProcTestCase(DBDakTestCase):
+ """
+ This TestCase checks that DBConn works with multiprocessing. A fresh
+ subprocess needs to call reset() on DBConn(). See function read_number()
+ for an example.
+ """
+
+ def save_result(self, result):
+ self.result += result
+
+ def test_seven(self):
+ '''
+ Test apply_async() with a database session.
+ '''
+ self.result = 0
+ pool = Pool()
+ pool.apply_async(read_number, (), callback = self.save_result)
+ pool.apply_async(read_number, (), callback = self.save_result)
+ pool.apply_async(read_number, (), callback = self.save_result)
+ pool.apply_async(read_number, (), callback = self.save_result)
+ pool.apply_async(read_number, (), callback = self.save_result)
+ pool.close()
+ pool.join()
+ self.assertEqual(5 * 7, self.result)
+
+if __name__ == '__main__':
+ unittest.main()
--
1.7.2.5
Reply to: