Skip to content

MlflowAzureMLHook

kedro_azureml_pipeline.hooks.MlflowAzureMLHook

Coordinates kedro-mlflow inside Azure ML pipeline component jobs.

Lifecycle
  1. after_context_created (tryfirst): pre-sets MLFLOW_EXPERIMENT_NAME so kedro-mlflow picks it up.
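  2. before_pipeline_run (tryfirst): resumes the Azure ML MLflow run when MLFLOW_RUN_ID is set and no run is active, then tags the active MLflow run with node name, pipeline run name, pipeline name, and Kedro environment.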
  3. on_pipeline_error: tags the run with error information.

See Also

AzureMLPipelineGenerator: injects the MLflow environment variables. AzureMLLocalRunHook: companion hook for dataset configuration.

Source Code

Show/Hide source
class MlflowAzureMLHook:
    """Coordinates kedro-mlflow inside Azure ML pipeline component jobs.

    Every hook returns immediately unless ``_is_mlflow_integration_active()``
    is true.

    Lifecycle
    ---------
    1. ``after_context_created`` (``tryfirst``): pre-sets
       ``MLFLOW_EXPERIMENT_NAME`` so kedro-mlflow picks it up.
    2. ``before_pipeline_run`` (``tryfirst``): resumes the run named by
       ``MLFLOW_RUN_ID`` when no run is active, then tags the active MLflow
       run with node name, pipeline run name, pipeline name, and kedro
       environment.
    3. ``on_pipeline_error``: tags the run with error information on a
       best-effort basis.

    See Also
    --------
    [AzureMLPipelineGenerator][kedro_azureml_pipeline.generator.AzureMLPipelineGenerator] : Injects MLflow env vars.
    [AzureMLLocalRunHook][kedro_azureml_pipeline.hooks.AzureMLLocalRunHook] : Companion hook for dataset config.
    """

    @hook_impl(tryfirst=True)
    def after_context_created(self, context) -> None:
        """Pre-set ``MLFLOW_EXPERIMENT_NAME`` for kedro-mlflow.

        Copies the value of the ``KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME``
        environment variable into ``MLFLOW_EXPERIMENT_NAME`` when it is set
        and non-empty.

        Parameters
        ----------
        context : KedroContext
            Kedro project context (not used by this hook).
        """
        if not _is_mlflow_integration_active():
            return

        experiment_name = os.environ.get(KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME)
        if experiment_name:
            os.environ["MLFLOW_EXPERIMENT_NAME"] = experiment_name
            logger.info("kedro-azureml-pipeline: set MLFLOW_EXPERIMENT_NAME=%s", experiment_name)

    @hook_impl(tryfirst=True)
    def before_pipeline_run(self, run_params, pipeline, catalog) -> None:
        """Tag the active MLflow run with Kedro metadata.

        When an experiment name is configured and no MLflow run is active,
        the hook first resumes the run identified by ``MLFLOW_RUN_ID`` (or,
        without a run id, only selects the experiment). If a run is active
        afterwards it receives the ``kedro.*`` tags and, when a node name is
        known, a descriptive ``mlflow.runName``.

        Parameters
        ----------
        run_params : dict
            Parameters passed to the run command; ``pipeline_name`` is read
            from it.
        pipeline : Pipeline
            Pipeline about to be run (not used by this hook).
        catalog : DataCatalog
            Data catalog (not used by this hook).
        """
        if not _is_mlflow_integration_active():
            return

        try:
            import mlflow
        except ImportError:
            logger.warning("kedro-azureml-pipeline: mlflow is not installed, skipping run tagging")
            return

        # Ensure the correct experiment is active before kedro-mlflow's hook
        # fires.  Without this, kedro-mlflow would resolve the experiment from
        # mlflow.yml (which may differ from the AzureML job experiment) and
        # pass a mismatched experiment_id to start_run(), causing an
        # MlflowException when MLFLOW_RUN_ID is set by AzureML.
        experiment_name = os.environ.get(KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME)
        if experiment_name and mlflow.active_run() is None:
            run_id = os.environ.get("MLFLOW_RUN_ID")
            if run_id:
                # The run already exists in AzureML under its own experiment.
                # kedro-mlflow may have called set_experiment() with a name
                # from mlflow.yml, setting _active_experiment_id to the wrong
                # value.  We must align it with the run's actual experiment
                # before calling start_run() to avoid an ID mismatch.
                client = mlflow.MlflowClient()
                run_info = client.get_run(run_id)
                mlflow.set_experiment(experiment_id=run_info.info.experiment_id)
                mlflow.start_run(run_id=run_id)
                logger.info(
                    "kedro-azureml-pipeline: resumed MLflow run %s (experiment_id=%s)",
                    run_id,
                    run_info.info.experiment_id,
                )
            else:
                mlflow.set_experiment(experiment_name)

        # Nothing to tag when no run is active at this point.
        active_run = mlflow.active_run()
        if active_run is None:
            return

        node_name = os.environ.get(KEDRO_AZUREML_MLFLOW_NODE_NAME, "")
        run_name = os.environ.get(KEDRO_AZUREML_MLFLOW_RUN_NAME, "")
        kedro_env = os.environ.get("KEDRO_ENV", "")

        # Only non-empty values become tags.
        tags = {}
        if node_name:
            tags["kedro.node_name"] = node_name
        if run_name:
            tags["kedro.pipeline_run_name"] = run_name
        if kedro_env:
            tags["kedro.env"] = kedro_env
        if run_params.get("pipeline_name"):
            tags["kedro.pipeline_name"] = run_params["pipeline_name"]

        if tags:
            mlflow.set_tags(tags)
            logger.info("kedro-azureml-pipeline: tagged MLflow run with %s", tags)

        # Set the child run name to include the node name for clarity
        if node_name:
            child_run_name = f"{run_name} :: {node_name}" if run_name else node_name
            mlflow.MlflowClient().set_tag(active_run.info.run_id, "mlflow.runName", child_run_name)

    @hook_impl
    def on_pipeline_error(self, error, run_params, pipeline, catalog) -> None:
        """Tag the MLflow run with error details.

        Sets ``kedro.error`` (the error message, truncated to 250 characters)
        and, when the node name environment variable is set,
        ``kedro.failed_node`` on the active run. Tagging is best-effort: a
        failure while writing the tags is logged as a warning instead of
        being raised, so it cannot mask the original pipeline error.

        Parameters
        ----------
        error : Exception
            The error that occurred.
        run_params : dict
            Parameters passed to the run command (not used by this hook).
        pipeline : Pipeline
            Pipeline that failed (not used by this hook).
        catalog : DataCatalog
            Data catalog (not used by this hook).
        """
        if not _is_mlflow_integration_active():
            return

        try:
            import mlflow
        except ImportError:
            return

        if mlflow.active_run() is None:
            return

        # Keep only the first 250 characters of the message
        # (presumably to stay within tag value limits -- TODO confirm).
        tags = {"kedro.error": str(error)[:250]}
        node_name = os.environ.get(KEDRO_AZUREML_MLFLOW_NODE_NAME, "")
        if node_name:
            tags["kedro.failed_node"] = node_name

        # This hook runs while a pipeline failure is already being handled;
        # an exception raised here would hide that failure from the caller.
        try:
            mlflow.set_tags(tags)
        except Exception:
            logger.warning(
                "kedro-azureml-pipeline: could not tag MLflow run with error details",
                exc_info=True,
            )

Methods

after_context_created(context)

Pre-set MLFLOW_EXPERIMENT_NAME for kedro-mlflow.

Parameters
Name Type Description Default
context KedroContext

Kedro project context.

required
Source Code
Show/Hide source
@hook_impl(tryfirst=True)
def after_context_created(self, context) -> None:
    """Export the Azure ML experiment name as ``MLFLOW_EXPERIMENT_NAME``.

    Does nothing unless the MLflow integration is active and the
    ``KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME`` environment variable holds a
    non-empty value.

    Parameters
    ----------
    context : KedroContext
        Kedro project context (not used by this hook).
    """
    if _is_mlflow_integration_active():
        azureml_experiment = os.environ.get(KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME)
        if azureml_experiment:
            # kedro-mlflow reads this variable, so it is set before its hooks run.
            os.environ["MLFLOW_EXPERIMENT_NAME"] = azureml_experiment
            logger.info("kedro-azureml-pipeline: set MLFLOW_EXPERIMENT_NAME=%s", azureml_experiment)

before_pipeline_run(run_params, pipeline, catalog)

Tag the active MLflow run with Kedro metadata.

Parameters
Name Type Description Default
run_params dict

Parameters passed to the run command.

required
pipeline Pipeline

Pipeline about to be run.

required
catalog DataCatalog

Data catalog.

required
Source Code
Show/Hide source
@hook_impl(tryfirst=True)
def before_pipeline_run(self, run_params, pipeline, catalog) -> None:
    """Attach Kedro metadata tags to the active MLflow run.

    If an experiment name is configured and no run is active, the run named
    by ``MLFLOW_RUN_ID`` is resumed first (without a run id only the
    experiment is selected). A run that is active afterwards receives the
    ``kedro.*`` tags and, when a node name is known, an ``mlflow.runName``
    tag.

    Parameters
    ----------
    run_params : dict
        Parameters passed to the run command; ``pipeline_name`` is read
        from it.
    pipeline : Pipeline
        Pipeline about to be run (not used by this hook).
    catalog : DataCatalog
        Data catalog (not used by this hook).
    """
    if not _is_mlflow_integration_active():
        return

    try:
        import mlflow
    except ImportError:
        logger.warning("kedro-azureml-pipeline: mlflow is not installed, skipping run tagging")
        return

    # The experiment has to be correct before kedro-mlflow's own hook runs:
    # kedro-mlflow would otherwise take the experiment from mlflow.yml, which
    # may differ from the AzureML job experiment, and hand a mismatched
    # experiment_id to start_run() -- an MlflowException once AzureML has
    # set MLFLOW_RUN_ID.
    azureml_experiment = os.environ.get(KEDRO_AZUREML_MLFLOW_EXPERIMENT_NAME)
    if azureml_experiment and mlflow.active_run() is None:
        existing_run_id = os.environ.get("MLFLOW_RUN_ID")
        if not existing_run_id:
            mlflow.set_experiment(azureml_experiment)
        else:
            # AzureML already created this run under its own experiment, while
            # kedro-mlflow may have pointed _active_experiment_id elsewhere via
            # set_experiment() with the mlflow.yml name.  Switch to the run's
            # real experiment first so start_run() sees matching ids.
            run = mlflow.MlflowClient().get_run(existing_run_id)
            experiment_id = run.info.experiment_id
            mlflow.set_experiment(experiment_id=experiment_id)
            mlflow.start_run(run_id=existing_run_id)
            logger.info(
                "kedro-azureml-pipeline: resumed MLflow run %s (experiment_id=%s)",
                existing_run_id,
                experiment_id,
            )

    current_run = mlflow.active_run()
    if current_run is None:
        return

    node_name = os.environ.get(KEDRO_AZUREML_MLFLOW_NODE_NAME, "")
    run_name = os.environ.get(KEDRO_AZUREML_MLFLOW_RUN_NAME, "")
    kedro_env = os.environ.get("KEDRO_ENV", "")

    # Candidate tags in their final order; empty values are dropped.
    candidates = (
        ("kedro.node_name", node_name),
        ("kedro.pipeline_run_name", run_name),
        ("kedro.env", kedro_env),
        ("kedro.pipeline_name", run_params.get("pipeline_name")),
    )
    tags = {key: value for key, value in candidates if value}

    if tags:
        mlflow.set_tags(tags)
        logger.info("kedro-azureml-pipeline: tagged MLflow run with %s", tags)

    # Name the child run after the node, prefixed by the run name if present.
    if node_name:
        if run_name:
            child_run_name = f"{run_name} :: {node_name}"
        else:
            child_run_name = node_name
        mlflow.MlflowClient().set_tag(current_run.info.run_id, "mlflow.runName", child_run_name)

on_pipeline_error(error, run_params, pipeline, catalog)

Tag the MLflow run with error details.

Parameters
Name Type Description Default
error Exception

The error that occurred.

required
run_params dict

Parameters passed to the run command.

required
pipeline Pipeline

Pipeline that failed.

required
catalog DataCatalog

Data catalog.

required
Source Code
Show/Hide source
@hook_impl
def on_pipeline_error(self, error, run_params, pipeline, catalog) -> None:
    """Tag the MLflow run with error details.

    Sets ``kedro.error`` (the error message, truncated to 250 characters)
    and, when the node name environment variable is set,
    ``kedro.failed_node`` on the active run. Tagging is best-effort: a
    failure while writing the tags is logged as a warning instead of being
    raised, so it cannot mask the original pipeline error.

    Parameters
    ----------
    error : Exception
        The error that occurred.
    run_params : dict
        Parameters passed to the run command (not used by this hook).
    pipeline : Pipeline
        Pipeline that failed (not used by this hook).
    catalog : DataCatalog
        Data catalog (not used by this hook).
    """
    if not _is_mlflow_integration_active():
        return

    try:
        import mlflow
    except ImportError:
        return

    if mlflow.active_run() is None:
        return

    # Keep only the first 250 characters of the message
    # (presumably to stay within tag value limits -- TODO confirm).
    tags = {"kedro.error": str(error)[:250]}
    node_name = os.environ.get(KEDRO_AZUREML_MLFLOW_NODE_NAME, "")
    if node_name:
        tags["kedro.failed_node"] = node_name

    # This hook runs while a pipeline failure is already being handled; an
    # exception raised here would hide that failure from the caller.
    try:
        mlflow.set_tags(tags)
    except Exception:
        logger.warning(
            "kedro-azureml-pipeline: could not tag MLflow run with error details",
            exc_info=True,
        )