From 3a7786c01b53eaeb3597aa36bf6f5b57243724d4 Mon Sep 17 00:00:00 2001 From: mprahl Date: Mon, 22 Apr 2019 12:40:14 -0400 Subject: [PATCH] Make _get_module consistent across resolvers This also adds additional code in the event a module is not returned. --- module_build_service/resolver/DBResolver.py | 14 ++++++++++---- module_build_service/resolver/MBSResolver.py | 4 +++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/module_build_service/resolver/DBResolver.py b/module_build_service/resolver/DBResolver.py index 80a8e916..6ed4bc7c 100644 --- a/module_build_service/resolver/DBResolver.py +++ b/module_build_service/resolver/DBResolver.py @@ -40,14 +40,18 @@ class DBResolver(GenericResolver): def __init__(self, config): self.config = config - def _get_module(self, name, stream, version, context, strict=False): + def _get_module( + self, name, stream, version, context, state=models.BUILD_STATES['ready'], strict=False, + ): with models.make_session(self.config) as session: mb = models.ModuleBuild.get_build_from_nsvc( - session, name, stream, version, context) - if mb is None and strict: + session, name, stream, version, context, state=state) + if mb: + return mb.extended_json() + + if strict: raise UnprocessableEntity( 'Cannot find any module builds for %s:%s' % (name, stream)) - return mb.extended_json() def get_module_modulemds(self, name, stream, version=None, context=None, strict=False, stream_version_lte=False): @@ -68,6 +72,8 @@ class DBResolver(GenericResolver): """ if version and context: mmd = self._get_module(name, stream, version, context, strict=strict) + if mmd is None: + return return [self.extract_modulemd(mmd['modulemd'])] with models.make_session(self.config) as session: diff --git a/module_build_service/resolver/MBSResolver.py b/module_build_service/resolver/MBSResolver.py index c538796f..28ecc90d 100644 --- a/module_build_service/resolver/MBSResolver.py +++ b/module_build_service/resolver/MBSResolver.py @@ -122,7 +122,9 @@ class MBSResolver(GenericResolver): return modules def _get_module(self, name, stream, version, context, state="ready", strict=False): - return self._get_modules(name, stream, version, context, state, strict)[0] + rv = self._get_modules(name, stream, version, context, state, strict) + if rv: + return rv[0] def get_module_modulemds(self, name, stream, version=None, context=None, strict=False, stream_version_lte=False):