diff --git a/conf/config.py b/conf/config.py index e5212c57..3ebdbab0 100644 --- a/conf/config.py +++ b/conf/config.py @@ -99,8 +99,7 @@ class TestConfiguration(BaseConfiguration): BUILD_LOGS_NAME_FORMAT = 'build-{id}.log' LOG_BACKEND = 'console' LOG_LEVEL = 'debug' - SQLALCHEMY_DATABASE_URI = 'sqlite:///{0}'.format( - path.join(dbdir, 'tests', 'test_module_build_service.db')) + SQLALCHEMY_DATABASE_URI = 'sqlite://' DEBUG = True MESSAGING = 'in_memory' PDC_URL = 'https://pdc.fedoraproject.org/rest_api/v1' diff --git a/module_build_service/__init__.py b/module_build_service/__init__.py index da91231d..432580ea 100644 --- a/module_build_service/__init__.py +++ b/module_build_service/__init__.py @@ -43,6 +43,7 @@ for a number of tasks: import pkg_resources from flask import Flask, has_app_context, url_for from flask_sqlalchemy import SQLAlchemy +from sqlalchemy.pool import StaticPool from logging import getLogger from module_build_service.logger import ( @@ -63,7 +64,28 @@ app = Flask(__name__) app.wsgi_app = ReverseProxy(app.wsgi_app) conf = init_config(app) -db = SQLAlchemy(app) + + +class MBSSQLAlchemy(SQLAlchemy): + """ + Inherits from SQLAlchemy and if SQLite in-memory database is used, + sets the driver options so multiple threads can share the same database. + + This is used *only* during tests to make them faster. + """ + def apply_driver_hacks(self, app, info, options): + if info.drivername == 'sqlite' and info.database in (None, '', ':memory:'): + options['poolclass'] = StaticPool + options['connect_args'] = {'check_same_thread': False} + try: + del options['pool_size'] + except KeyError: + pass + + super(MBSSQLAlchemy, self).apply_driver_hacks(app, info, options) + + +db = MBSSQLAlchemy(app) def create_app(debug=False, verbose=False, quiet=False): diff --git a/module_build_service/models.py b/module_build_service/models.py index 160a1af8..9b459597 100644 --- a/module_build_service/models.py +++ b/module_build_service/models.py @@ -106,11 +106,31 @@ def _dummy_context_mgr(): yield None +def _setup_event_listeners(session): + """ + Starts listening for events related to database session. + """ + if not sqlalchemy.event.contains( + session, 'before_commit', session_before_commit_handlers): + sqlalchemy.event.listen(session, 'before_commit', + session_before_commit_handlers) + + @contextlib.contextmanager def make_session(conf): """ Yields new SQLAlchemy database sesssion. """ + + # Do not use scoped_session in case we are using in-memory database, + # because we want to use the same session across all threads to be able + # to use the same in-memory database in tests. + if conf.sqlalchemy_database_uri == 'sqlite://': + _setup_event_listeners(db.session) + yield db.session + db.session.commit() + return + # Needs to be set to create app_context. if 'SERVER_NAME' not in app.config or not app.config['SERVER_NAME']: app.config['SERVER_NAME'] = 'localhost' @@ -126,7 +146,7 @@ def make_session(conf): 'sqlalchemy.url': conf.sqlalchemy_database_uri, }) session = scoped_session(sessionmaker(bind=engine))() - sqlalchemy.event.listen(session, "before_commit", session_before_commit_handlers) + _setup_event_listeners(session) try: yield session session.commit() diff --git a/tests/__init__.py b/tests/__init__.py index b47914d3..b54266f0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -51,7 +51,7 @@ patch_config() def clean_database(): - db.session.remove() + db.session.commit() db.drop_all() db.create_all() diff --git a/tests/test_build/test_build.py b/tests/test_build/test_build.py index 1d06bd43..038b9ef2 100644 --- a/tests/test_build/test_build.py +++ b/tests/test_build/test_build.py @@ -174,7 +174,7 @@ class FakeModuleBuilder(GenericBuilder): for nvr in artifacts: # tag_artifacts received a list of NVRs, but the tag message expects the # component name - artifact = models.ComponentBuild.query.filter_by(nvr=nvr).one().package + artifact = models.ComponentBuild.query.filter_by(nvr=nvr).first().package self._send_tag(artifact, dest_tag=dest_tag) @property