diff --git a/module_build_service/models.py b/module_build_service/models.py index ca263a1c..3b9aaaf7 100644 --- a/module_build_service/models.py +++ b/module_build_service/models.py @@ -1093,7 +1093,7 @@ class ModuleBuild(MBSBase): def get_buildrequired_base_modules(self, db_session): """ - Find the base modules in the modulemd's xmd section. + Find the base modules in the modulemd's xmd/mbs/buildrequires section. :param db_session: the SQLAlchemy database session to use to query :return: a list of ModuleBuild objects of the base modules that are buildrequired with the @@ -1107,7 +1107,7 @@ class ModuleBuild(MBSBase): try: bm_dict = xmd["mbs"]["buildrequires"].get(bm) except KeyError: - raise RuntimeError("The module's mmd is missing information in the xmd section") + raise RuntimeError("The module's mmd is missing xmd/mbs or xmd/mbs/buildrequires.") if not bm_dict: continue @@ -1139,6 +1139,40 @@ class ModuleBuild(MBSBase): self.state_reason, ) + def update_virtual_streams(self, db_session, virtual_streams): + """Add and remove virtual streams to and from this build + + If a virtual stream is only associated with this build, remove it from + database as well. + + :param db_session: SQLAlchemy session object. + :param virtual_streams: list of virtual streams names used to update + this build's virtual streams. + :type virtual_streams: list[str] + """ + orig_virtual_streams = set(item.name for item in self.virtual_streams) + new_virtual_streams = set(virtual_streams) + + dropped_virtual_streams = orig_virtual_streams - new_virtual_streams + newly_added_virtual_streams = new_virtual_streams - orig_virtual_streams + + for stream_name in newly_added_virtual_streams: + virtual_stream = VirtualStream.get_by_name(db_session, stream_name) + if not virtual_stream: + virtual_stream = VirtualStream(name=stream_name) + self.virtual_streams.append(virtual_stream) + + for stream_name in dropped_virtual_streams: + virtual_stream = VirtualStream.get_by_name(db_session, stream_name) + only_associated_with_self = ( + len(virtual_stream.module_builds) == 1 + and virtual_stream.module_builds[0].id == self.id + ) + + self.virtual_streams.remove(virtual_stream) + if only_associated_with_self: + db_session.delete(virtual_stream) + class VirtualStream(MBSBase): __tablename__ = "virtual_streams" @@ -1151,6 +1185,10 @@ class VirtualStream(MBSBase): def __repr__(self): return "".format(self.id, self.name) + @classmethod + def get_by_name(cls, db_session, name): + return db_session.query(cls).filter_by(name=name).first() + class ModuleArch(MBSBase): __tablename__ = "module_arches" diff --git a/module_build_service/utils/general.py b/module_build_service/utils/general.py index 74d4e7c5..71823d5c 100644 --- a/module_build_service/utils/general.py +++ b/module_build_service/utils/general.py @@ -412,6 +412,8 @@ def import_mmd(db_session, mmd, check_buildrequires=True): The ModuleBuild.owner is set to "mbs_import". :param db_session: SQLAlchemy session object. + :param mmd: module metadata being imported into database. + :type mmd: Modulemd.ModuleStream :param bool check_buildrequires: When True, checks that the buildrequires defined in the MMD have matching records in the `mmd["xmd"]["mbs"]["buildrequires"]` and also fills in the `ModuleBuild.buildrequires` according to this data. @@ -419,22 +421,43 @@ def import_mmd(db_session, mmd, check_buildrequires=True): log messages collected during import (list) :rtype: tuple """ + xmd = mmd.get_xmd() + # Set some defaults in xmd["mbs"] if they're not provided by the user + if "mbs" not in xmd: + xmd["mbs"] = {"mse": True} + if not mmd.get_context(): mmd.set_context(models.DEFAULT_MODULE_CONTEXT) + + # NSVC is used for logging purpose later. + nsvc = mmd.get_nsvc() + if nsvc is None: + msg = "Both the name and stream must be set for the modulemd being imported." + log.error(msg) + raise UnprocessableEntity(msg) + name = mmd.get_module_name() stream = mmd.get_stream_name() version = str(mmd.get_version()) context = mmd.get_context() - try: - disttag_marking = mmd.get_xmd()["mbs"]["disttag_marking"] - except (ValueError, KeyError): - disttag_marking = None + xmd_mbs = xmd["mbs"] - try: - virtual_streams = mmd.get_xmd()["mbs"]["virtual_streams"] - except (ValueError, KeyError): - virtual_streams = [] + disttag_marking = xmd_mbs.get("disttag_marking") + + # If it is a base module, then make sure the value that will be used in the RPM disttags + # doesn't contain a dash since a dash isn't allowed in the release field of the NVR + if name in conf.base_module_names: + if disttag_marking and "-" in disttag_marking: + msg = "The disttag_marking cannot contain a dash" + log.error(msg) + raise UnprocessableEntity(msg) + if not disttag_marking and "-" in stream: + msg = "The stream cannot contain a dash unless disttag_marking is set" + log.error(msg) + raise UnprocessableEntity(msg) + + virtual_streams = xmd_mbs.get("virtual_streams", []) # Verify that the virtual streams are the correct type if virtual_streams and ( @@ -445,54 +468,34 @@ def import_mmd(db_session, mmd, check_buildrequires=True): log.error(msg) raise UnprocessableEntity(msg) - # If it is a base module, then make sure the value that will be used in the RPM disttags - # doesn't contain a dash since a dash isn't allowed in the release field of the NVR - if name in conf.base_module_names: - if disttag_marking and "-" in disttag_marking: - msg = "The disttag_marking cannot contain a dash" - log.error(msg) - raise UnprocessableEntity(msg) - elif not disttag_marking and "-" in stream: - msg = "The stream cannot contain a dash unless disttag_marking is set" - log.error(msg) - raise UnprocessableEntity(msg) + if check_buildrequires: + deps = mmd.get_dependencies() + if len(deps) > 1: + raise UnprocessableEntity( + "The imported module's dependencies list should contain just one element") + + if "buildrequires" not in xmd_mbs: + # Always set buildrequires if it is not there, because + # get_buildrequired_base_modules requires xmd/mbs/buildrequires exists. + xmd_mbs["buildrequires"] = {} + mmd.set_xmd(xmd) + + if len(deps) > 0: + brs = set(deps[0].get_buildtime_modules()) + xmd_brs = set(xmd_mbs["buildrequires"].keys()) + if brs - xmd_brs: + raise UnprocessableEntity( + "The imported module buildrequires other modules, but the metadata in the " + 'xmd["mbs"]["buildrequires"] dictionary is missing entries' + ) + + if "koji_tag" not in xmd_mbs: + log.warning("'koji_tag' is not set in xmd['mbs'] for module {}".format(nsvc)) + log.warning("koji_tag will be set to None for imported module build.") # Log messages collected during import msgs = [] - # NSVC is used for logging purpose later. - try: - nsvc = ":".join([name, stream, version, context]) - except TypeError: - msg = "Incomplete NSVC: {}:{}:{}:{}".format(name, stream, version, context) - log.error(msg) - raise UnprocessableEntity(msg) - - if len(mmd.get_dependencies()) > 1: - raise UnprocessableEntity( - "The imported module's dependencies list should contain just one element") - - xmd = mmd.get_xmd() - # Set some defaults in xmd["mbs"] if they're not provided by the user - if "mbs" not in xmd: - xmd["mbs"] = {"mse": True} - - if check_buildrequires and mmd.get_dependencies(): - brs = set(mmd.get_dependencies()[0].get_buildtime_modules()) - xmd_brs = set(xmd["mbs"].get("buildrequires", {}).keys()) - if brs - xmd_brs: - raise UnprocessableEntity( - "The imported module buildrequires other modules, but the metadata in the " - 'xmd["mbs"]["buildrequires"] dictionary is missing entries' - ) - elif "buildrequires" not in xmd["mbs"]: - xmd["mbs"]["buildrequires"] = {} - mmd.set_xmd(xmd) - - koji_tag = xmd["mbs"].get("koji_tag") - if koji_tag is None: - log.warning("'koji_tag' is not set in xmd['mbs'] for module {}".format(nsvc)) - # Get the ModuleBuild from DB. build = models.ModuleBuild.get_build_from_nsvc(db_session, name, stream, version, context) if build: @@ -501,11 +504,12 @@ def import_mmd(db_session, mmd, check_buildrequires=True): msgs.append(msg) else: build = models.ModuleBuild() + db_session.add(build) build.name = name build.stream = stream build.version = version - build.koji_tag = koji_tag + build.koji_tag = xmd_mbs.get("koji_tag") build.state = models.BUILD_STATES["ready"] build.modulemd = mmd_to_str(mmd) build.context = context @@ -523,19 +527,7 @@ def import_mmd(db_session, mmd, check_buildrequires=True): if base_module not in build.buildrequires: build.buildrequires.append(base_module) - db_session.add(build) - db_session.commit() - - for virtual_stream in virtual_streams: - vs_obj = db_session.query(models.VirtualStream).filter_by(name=virtual_stream).first() - if not vs_obj: - vs_obj = models.VirtualStream(name=virtual_stream) - db_session.add(vs_obj) - db_session.commit() - - if vs_obj not in build.virtual_streams: - build.virtual_streams.append(vs_obj) - db_session.add(build) + build.update_virtual_streams(db_session, virtual_streams) db_session.commit() diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 0b7e5922..fc79dca3 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -455,6 +455,63 @@ class TestUtils: else: module_build_service.utils.import_mmd(db_session, mmd) + def test_import_mmd_remove_dropped_virtual_streams(self, db_session): + mmd = load_mmd(read_staged_data("formatted_testmodule")) + + # Add some virtual streams + xmd = mmd.get_xmd() + xmd["mbs"]["virtual_streams"] = ["f28", "f29", "f30"] + mmd.set_xmd(xmd) + + # Import mmd into database to simulate the next step to reimport a module + module_build_service.utils.general.import_mmd(db_session, mmd) + + # Now, remove some virtual streams from module metadata + xmd = mmd.get_xmd() + xmd["mbs"]["virtual_streams"] = ["f28", "f29"] # Note that, f30 is removed + mmd.set_xmd(xmd) + + # Test import modulemd again and the f30 should be removed from database. + module_build, _ = module_build_service.utils.general.import_mmd(db_session, mmd) + + db_session.refresh(module_build) + assert ["f28", "f29"] == sorted(item.name for item in module_build.virtual_streams) + assert 0 == db_session.query(models.VirtualStream).filter_by(name="f30").count() + + def test_import_mmd_dont_remove_dropped_virtual_streams_associated_with_other_modules( + self, db_session + ): + mmd = load_mmd(read_staged_data("formatted_testmodule")) + # Add some virtual streams to this module metadata + xmd = mmd.get_xmd() + xmd["mbs"]["virtual_streams"] = ["f28", "f29", "f30"] + mmd.set_xmd(xmd) + module_build_service.utils.general.import_mmd(db_session, mmd) + + # Import another module which has overlapping virtual streams + another_mmd = load_mmd(read_staged_data("formatted_testmodule-more-components")) + # Add some virtual streams to this module metadata + xmd = another_mmd.get_xmd() + xmd["mbs"]["virtual_streams"] = ["f29", "f30"] + another_mmd.set_xmd(xmd) + another_module_build, _ = module_build_service.utils.general.import_mmd( + db_session, another_mmd) + + # Now, remove f30 from mmd + xmd = mmd.get_xmd() + xmd["mbs"]["virtual_streams"] = ["f28", "f29"] + mmd.set_xmd(xmd) + + # Reimport formatted_testmodule again + module_build, _ = module_build_service.utils.general.import_mmd(db_session, mmd) + + db_session.refresh(module_build) + assert ["f28", "f29"] == sorted(item.name for item in module_build.virtual_streams) + + # The overlapped f30 should be still there. + db_session.refresh(another_module_build) + assert ["f29", "f30"] == sorted(item.name for item in another_module_build.virtual_streams) + def test_get_rpm_release_mse(self, db_session): init_data(contexts=True) diff --git a/tests/test_views/test_views.py b/tests/test_views/test_views.py index f4e89713..33b33c90 100644 --- a/tests/test_views/test_views.py +++ b/tests/test_views/test_views.py @@ -2019,7 +2019,8 @@ class TestViews: data = json.loads(rv.data) assert data["error"] == "Unprocessable Entity" - assert data["message"] == "Incomplete NSVC: None:None:0:00000000" + expected_msg = "Both the name and stream must be set for the modulemd being imported." + assert data["message"] == expected_msg @pytest.mark.parametrize("api_version", [1, 2]) @patch("module_build_service.auth.get_user", return_value=import_module_user)