How to Run Distributed Training
This guide shows how to mark Kedro nodes for distributed execution on Azure ML using the @distributed_job decorator.
Prerequisites
- A Kedro project with the plugin installed and configured (see Getting Started)
- An Azure ML compute cluster with multiple nodes available
- An Azure ML environment with your distributed training framework installed (PyTorch, TensorFlow, or MPI)
Decorate a node function
Import the decorator and mark the function you want to run as a distributed step:
from kedro_azureml_pipeline.distributed import distributed_job, Framework

@distributed_job(framework=Framework.PyTorch, num_nodes=4)
def train_model(X_train, y_train):
    import torch.distributed as dist
    dist.init_process_group("nccl")
    # ... distributed training logic
    return trained_model
Use the decorated function when registering your Kedro node:
from kedro.pipeline import node

train_node = node(
    func=train_model,
    inputs=["X_train", "y_train"],
    outputs="trained_model",
    name="train_model_node",
)
Supported frameworks
| Framework enum value | Distributed backend |
|---|---|
| Framework.PyTorch | PyTorch distributed (NCCL or Gloo) |
| Framework.TensorFlow | TensorFlow distributed strategy |
| Framework.MPI | MPI (Message Passing Interface) |
Set the number of processes per node
Use processes_per_node to launch multiple worker processes on each node:
@distributed_job(framework=Framework.PyTorch, num_nodes=4, processes_per_node=8)
def train_model(X_train, y_train):
    ...
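To see how the two settings combine, the sketch below computes the total number of worker processes, assuming one process is launched for every `processes_per_node` slot on every node (the usual convention for PyTorch-style launchers). The helper name `total_processes` is hypothetical and is not part of the plugin.

```python
def total_processes(num_nodes: int, processes_per_node: int = 1) -> int:
    """Return the world size implied by a node count and a per-node process count.

    Hypothetical helper for illustration only; not part of the plugin's API.
    """
    if num_nodes < 1 or processes_per_node < 1:
        raise ValueError("num_nodes and processes_per_node must both be >= 1")
    return num_nodes * processes_per_node


# The decorator arguments shown above: 4 nodes with 8 processes each
world_size = total_processes(num_nodes=4, processes_per_node=8)
print(world_size)  # 32
```

For GPU training, `processes_per_node` is typically set to the number of GPUs on each node so that every process drives one device.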
Use a Kedro parameter for node count
Pass a params: reference to make the node count configurable per environment:
@distributed_job(framework=Framework.PyTorch, num_nodes="params:num_training_nodes")
def train_model(X_train, y_train):
    ...
Then set the value of num_training_nodes in conf/base/parameters.yml.
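A minimal example of that entry, assuming four nodes. The value is illustrative; the key has to match the name used in the `params:` reference:

```yaml
# conf/base/parameters.yml
num_training_nodes: 4  # illustrative value; use what your cluster can provide
```

Because Kedro merges configuration per environment, a parameters file in another environment folder (for example `conf/local/parameters.yml`) can override this value.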
How it works
During local runs, @distributed_job has no effect and the function runs normally. During Azure ML runs, the pipeline generator wraps the step in a distributed job configuration. See the architecture overview for details on pipeline compilation.
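The sketch below illustrates the general "marker decorator" pattern this behavior suggests: the function is called unchanged, and the settings are attached as metadata that a pipeline generator could read later. It is not the plugin's implementation; `mark_distributed`, `DistributedSettings`, and `read_settings` are hypothetical names chosen for illustration.

```python
import functools
from dataclasses import dataclass
from typing import Callable, Optional, Union


@dataclass(frozen=True)
class DistributedSettings:
    """Metadata a pipeline generator could read later (hypothetical structure)."""
    framework: str
    num_nodes: Union[int, str]  # an int, or a "params:..." reference
    processes_per_node: int = 1


def mark_distributed(
    framework: str, num_nodes: Union[int, str], processes_per_node: int = 1
) -> Callable:
    """Hypothetical stand-in for a marker decorator such as @distributed_job."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Local runs: call straight through, with no distributed setup
            return func(*args, **kwargs)

        # Attach the settings so a generator can detect the marker
        wrapper.distributed_settings = DistributedSettings(
            framework, num_nodes, processes_per_node
        )
        return wrapper
    return decorator


def read_settings(func: Callable) -> Optional[DistributedSettings]:
    """Return the attached settings, or None for an unmarked function."""
    return getattr(func, "distributed_settings", None)


@mark_distributed(framework="pytorch", num_nodes=4)
def add(a, b):
    return a + b


print(add(1, 2))                     # 3
print(read_settings(add).num_nodes)  # 4
```

Keeping the decorator free of side effects is what lets the same node function run in an ordinary local Kedro run without any distributed setup.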
Checking rank inside a node
Use is_distributed_master_node() to check whether the current process is rank 0. This is useful for logging or saving artifacts only from the master node.
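A minimal sketch of the rank-0 guard pattern follows. So that it runs without the plugin installed, it uses a hypothetical stand-in, `is_rank_zero`, which assumes the launcher exports the global rank as `RANK` (torchrun-style) or `OMPI_COMM_WORLD_RANK` (Open MPI). In a real node you would call `is_distributed_master_node()` instead; how the plugin determines the rank internally is not described here.

```python
import os
from typing import Mapping, Optional


def is_rank_zero(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical stand-in for is_distributed_master_node().

    Assumes the launcher exports the global rank as RANK (torchrun-style)
    or OMPI_COMM_WORLD_RANK (Open MPI). With neither set, the process is
    treated as a single local run and therefore as rank 0.
    """
    env = os.environ if environ is None else environ
    for var in ("RANK", "OMPI_COMM_WORLD_RANK"):
        if var in env:
            return int(env[var]) == 0
    return True


def save_once(
    artifact: str, store: list, environ: Optional[Mapping[str, str]] = None
) -> bool:
    """Append the artifact to the store only on rank 0; report whether it was saved."""
    if is_rank_zero(environ):
        store.append(artifact)
        return True
    return False


saved = []
save_once("model.pt", saved, {"RANK": "0"})  # rank 0 saves
save_once("model.pt", saved, {"RANK": "3"})  # other ranks skip
print(saved)  # ['model.pt']
```

Guarding writes this way avoids several processes overwriting the same output or emitting duplicate log lines.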
Note
If your compute cluster has fewer nodes than num_nodes, Azure ML queues the job until enough nodes become available. The job will not fail immediately, but it may wait indefinitely if the cluster's maximum node count is lower than the requested count.
See also
- Architecture overview for how the pipeline generator translates Kedro nodes to Azure ML steps
- distributed_job API for the decorator parameter reference
- Framework API for supported framework values