
distributed_job

kedro_azureml_pipeline.distributed.distributed_job(framework, num_nodes, **kwargs)

Mark a Kedro node function for distributed execution.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `framework` | `Framework` | Distributed framework (PyTorch, TensorFlow, or MPI). | required |
| `num_nodes` | `str` or `int` | Number of compute nodes, given either as an integer or as a `params:` reference string. | required |
| `**kwargs` | | Extra fields forwarded to `DistributedNodeConfig`. | `{}` |
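The `num_nodes` argument takes either a literal node count or a `params:` reference string. This page does not document how the library resolves such references. The sketch below assumes Kedro-style dotted lookups into a parameters dictionary, and `resolve_num_nodes` is a hypothetical helper, not part of `kedro_azureml_pipeline`.

```python
from __future__ import annotations

from typing import Any, Mapping


def resolve_num_nodes(num_nodes: str | int, params: Mapping[str, Any]) -> int:
    """Turn an int or a ``params:`` reference into a node count (hypothetical helper)."""
    if isinstance(num_nodes, int):
        return num_nodes
    prefix = "params:"
    if not num_nodes.startswith(prefix):
        raise ValueError(f"Expected an int or a 'params:' reference, got {num_nodes!r}")
    # Assumption: dotted keys such as "params:training.num_nodes" walk nested dicts.
    value: Any = params
    for key in num_nodes[len(prefix):].split("."):
        value = value[key]
    return int(value)


params = {"training": {"num_nodes": 4}}
print(resolve_num_nodes(2, params))                            # 2
print(resolve_num_nodes("params:training.num_nodes", params))  # 4
```

Keeping the reference as a string until pipeline generation lets the node count change per environment without touching the decorated function.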

Returns

| Type | Description |
|---|---|
| `callable` | Decorator that attaches a `DistributedNodeConfig` to the wrapped function. |
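`distributed_job` is a decorator factory: it must be called with its arguments, and the callable it returns is what decorates the node function. The sketch below illustrates that usage with minimal stand-ins so it runs without the package installed. The `Framework` member names, the two-field `DistributedNodeConfig`, and the attribute name stored in `DISTRIBUTED_CONFIG_FIELD` are assumptions, not the library's actual definitions; in a real project you would import `distributed_job` from `kedro_azureml_pipeline.distributed` instead.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from functools import wraps


class Framework(Enum):
    # Stand-in: the real enum's member names and values may differ.
    PYTORCH = "pytorch"
    TENSORFLOW = "tensorflow"
    MPI = "mpi"


@dataclass
class DistributedNodeConfig:
    # Stand-in holding only the two documented positional fields.
    framework: Framework
    num_nodes: str | int


DISTRIBUTED_CONFIG_FIELD = "_distributed_config"  # hypothetical attribute name


def distributed_job(framework: Framework, num_nodes: str | int, **kwargs):
    # Mirrors the structure of the source listing on this page.
    def _decorator(func):
        config = DistributedNodeConfig(framework, num_nodes, **kwargs)
        setattr(func, DISTRIBUTED_CONFIG_FIELD, config)

        @wraps(func)
        def wrapper(*args, **kws):
            return func(*args, **kws)

        return wrapper

    return _decorator


# Note the call: distributed_job(...) returns the actual decorator.
@distributed_job(Framework.PYTORCH, num_nodes=2)
def train_model(data: list[float]) -> float:
    return sum(data) / len(data)


config = getattr(train_model, DISTRIBUTED_CONFIG_FIELD)
print(config.framework, config.num_nodes)  # Framework.PYTORCH 2
print(train_model([1.0, 2.0, 3.0]))        # 2.0
```

The decorated function returns the same result and keeps its `__name__`, so it can still be passed to a Kedro node as an ordinary function.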

See Also

- DistributedNodeConfig : Config attached by this decorator.
- Framework : Supported frameworks.
- AzureMLPipelineGenerator : Reads the attached config.
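`AzureMLPipelineGenerator` reads the attached config, but this page does not show how. A minimal sketch, assuming the consumer simply looks the attribute up with `getattr` and treats a missing attribute as an ordinary single-node step; `get_distributed_config` and the attribute name are hypothetical, and a plain dict stands in for `DistributedNodeConfig`.

```python
from typing import Any, Callable, Optional

DISTRIBUTED_CONFIG_FIELD = "_distributed_config"  # hypothetical attribute name


def get_distributed_config(func: Callable[..., Any]) -> Optional[Any]:
    """Return the config attached by the decorator, or None for ordinary nodes."""
    return getattr(func, DISTRIBUTED_CONFIG_FIELD, None)


def preprocess(x):
    return x


def train(x):
    return x


# Simulate what the decorator does, without importing the library.
setattr(train, DISTRIBUTED_CONFIG_FIELD, {"framework": "pytorch", "num_nodes": 2})

for fn in (preprocess, train):
    kind = "distributed" if get_distributed_config(fn) is not None else "single-node"
    print(f"{fn.__name__}: {kind}")
# preprocess: single-node
# train: distributed
```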

Source Code

def distributed_job(framework: Framework, num_nodes: str | int, **kwargs):
    """Mark a Kedro node function for distributed execution.

    Parameters
    ----------
    framework : Framework
        Distributed framework (PyTorch, TensorFlow, or MPI).
    num_nodes : str or int
        Number of compute nodes, or a ``params:`` reference.
    **kwargs
        Extra fields forwarded to ``DistributedNodeConfig``.

    Returns
    -------
    callable
        Decorator that attaches a ``DistributedNodeConfig`` to the
        wrapped function.

    See Also
    --------
    [DistributedNodeConfig][kedro_azureml_pipeline.distributed.config.DistributedNodeConfig] : Config attached by this decorator.
    [Framework][kedro_azureml_pipeline.distributed.config.Framework] : Supported frameworks.
    [AzureMLPipelineGenerator][kedro_azureml_pipeline.generator.AzureMLPipelineGenerator] : Reads the attached config.
    """

    def _decorator(func):
        """Attach distributed config to *func*."""
        config = DistributedNodeConfig(framework, num_nodes, **kwargs)
        setattr(
            func,
            DISTRIBUTED_CONFIG_FIELD,
            config,
        )

        @wraps(func)
        def wrapper(*args, **kws):
            """Forward call to the original function."""
            return func(*args, **kws)

        return wrapper

    return _decorator
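The source sets the config on `func` before applying `functools.wraps`, and that ordering matters. `wraps` (through `functools.update_wrapper`) copies the wrapped function's `__dict__` into the wrapper at decoration time, which is how the attribute becomes visible on the returned `wrapper`; an attribute set on `func` afterwards is not copied. The sketch below demonstrates this with plain functions; `FIELD` and the helper names are hypothetical.

```python
from functools import wraps

FIELD = "_distributed_config"  # hypothetical attribute name


def attach_then_wrap(func):
    setattr(func, FIELD, "config")  # set first, as the source listing does

    @wraps(func)  # copies func.__dict__ (now containing FIELD) into wrapper
    def wrapper(*args, **kws):
        return func(*args, **kws)

    return wrapper


def wrap_then_attach(func):
    @wraps(func)  # copies func.__dict__ while it is still empty
    def wrapper(*args, **kws):
        return func(*args, **kws)

    setattr(func, FIELD, "config")  # too late: wrapper never sees it
    return wrapper


def node_fn():
    return 42


def other_fn():
    return 42


a = attach_then_wrap(node_fn)
b = wrap_then_attach(other_fn)
print(hasattr(a, FIELD))        # True
print(hasattr(b, FIELD))        # False
print(a.__wrapped__ is node_fn)  # True
```

`wraps` also sets `__wrapped__`, so the undecorated function (which carries the same attribute) remains reachable from the wrapper.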