diff --git a/mlflow/tracking/_model_registry/utils.py b/mlflow/tracking/_model_registry/utils.py index 2101388915ece..123d5af0f85e6 100644 --- a/mlflow/tracking/_model_registry/utils.py +++ b/mlflow/tracking/_model_registry/utils.py @@ -24,6 +24,8 @@ get_db_info_from_uri, is_databricks_uri, ) +from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore +from mlflow.store._unity_catalog.registry.uc_oss_rest_store import UnityCatalogOssStore # NOTE: in contrast to tracking, we do not support the following ways to specify # the model registry URI: @@ -217,30 +219,28 @@ def _get_file_store(store_uri, **_): def _get_store_registry(): global _model_registry_store_registry - from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore - from mlflow.store._unity_catalog.registry.uc_oss_rest_store import UnityCatalogOssStore - if _model_registry_store_registry is not None: - return _model_registry_store_registry + registry = _model_registry_store_registry + if registry is not None: + return registry - _model_registry_store_registry = ModelRegistryStoreRegistry() - _model_registry_store_registry.register("databricks", _get_databricks_rest_store) - # Register a placeholder function that raises if users pass a registry URI with scheme - # "databricks-uc" - _model_registry_store_registry.register(_DATABRICKS_UNITY_CATALOG_SCHEME, UcModelRegistryStore) - _model_registry_store_registry.register(_OSS_UNITY_CATALOG_SCHEME, UnityCatalogOssStore) + registry = ModelRegistryStoreRegistry() + _model_registry_store_registry = registry + registry.register("databricks", _get_databricks_rest_store) + registry.register(_DATABRICKS_UNITY_CATALOG_SCHEME, UcModelRegistryStore) + registry.register(_OSS_UNITY_CATALOG_SCHEME, UnityCatalogOssStore) - for scheme in ["http", "https"]: - _model_registry_store_registry.register(scheme, _get_rest_store) + reg = registry.register + for scheme in ("http", "https"): + reg(scheme, _get_rest_store) for scheme in DATABASE_ENGINES: - _model_registry_store_registry.register(scheme, _get_sqlalchemy_store) + reg(scheme, _get_sqlalchemy_store) + for scheme in ("", "file"): + reg(scheme, _get_file_store) - for scheme in ["", "file"]: - _model_registry_store_registry.register(scheme, _get_file_store) - - _model_registry_store_registry.register_entrypoints() - return _model_registry_store_registry + registry.register_entrypoints() + return registry def _get_store(store_uri=None, tracking_uri=None): diff --git a/mlflow/tracking/registry.py b/mlflow/tracking/registry.py index 0aef285161d49..c4e6edc628445 100644 --- a/mlflow/tracking/registry.py +++ b/mlflow/tracking/registry.py @@ -48,16 +48,13 @@ def register(self, scheme, store_builder): def register_entrypoints(self): """Register tracking stores provided by other packages""" + _register = self.register for entrypoint in get_entry_points(self.group_name): try: - self.register(entrypoint.name, entrypoint.load()) + _register(entrypoint.name, entrypoint.load()) except (AttributeError, ImportError) as exc: - warnings.warn( - 'Failure attempting to register store for scheme "{}": {}'.format( - entrypoint.name, str(exc) - ), - stacklevel=2, - ) + msg = f'Failure attempting to register store for scheme "{entrypoint.name}": {exc}' + warnings.warn(msg, stacklevel=2) def get_store_builder(self, store_uri): """Get a store from the registry based on the scheme of store_uri