
Distributed Execution

dagron's distributed execution system lets you run DAG nodes across different concurrency and distribution primitives -- from local thread pools to Ray clusters and Celery workers -- using a single, unified API. A pluggable DistributedBackend protocol abstracts away the transport, so you can switch from threads to Ray by changing one line.

For large DAGs, the PartitionedDAGExecutor splits the graph into partitions and executes each partition as a unit, minimizing cross-partition communication.

A DAG split into 3 partitions. Each partition runs on a different worker.


Architecture Overview

There are two main approaches to distributed execution:

| Executor | Approach | Best for |
| --- | --- | --- |
| DistributedExecutor | Dispatches individual nodes to a backend by topological level. | Fine-grained distribution where each node runs independently. |
| PartitionedDAGExecutor | Splits the DAG into k partitions, executes each partition as a sub-DAG. | Coarse-grained distribution that minimizes serialization overhead. |

Both use the DistributedBackend protocol for the actual task dispatch.


DistributedBackend Protocol

All backends expose a name property and implement three methods:

```python
from typing import Any, Protocol

class DistributedBackend(Protocol):
    @property
    def name(self) -> str: ...

    def submit(self, fn, *args, **kwargs) -> Any:
        """Submit a callable for execution. Returns a future."""
        ...

    def result(self, future, timeout=None) -> Any:
        """Retrieve the result of a submitted task."""
        ...

    def shutdown(self, wait=True) -> None:
        """Shut down the backend and release resources."""
        ...
```
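Because DistributedBackend is a structural Protocol, any object with these members conforms; no subclassing is required. A minimal in-process sketch using only the standard library (the LocalBackend name here is illustrative, not part of dagron):

```python
from concurrent.futures import ThreadPoolExecutor

class LocalBackend:
    """Toy backend: satisfies the protocol with a plain thread pool."""

    def __init__(self, max_workers=4):
        self._pool = ThreadPoolExecutor(max_workers=max_workers)

    @property
    def name(self):
        return "local"

    def submit(self, fn, *args, **kwargs):
        return self._pool.submit(fn, *args, **kwargs)

    def result(self, future, timeout=None):
        return future.result(timeout=timeout)

    def shutdown(self, wait=True):
        self._pool.shutdown(wait=wait)

backend = LocalBackend()
future = backend.submit(pow, 2, 10)
print(backend.result(future))  # 1024
backend.shutdown()
```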

dagron ships with four backends:

| Backend | Module | Use case |
| --- | --- | --- |
| ThreadBackend | dagron.execution.backends.thread | I/O-bound tasks, testing, development. |
| MultiprocessingBackend | dagron.execution.backends.multiprocessing | CPU-bound tasks on a single machine. |
| RayBackend | dagron.execution.backends.ray | Multi-machine clusters. Requires pip install dagron[ray]. |
| CeleryBackend | dagron.execution.backends.celery | Existing Celery infrastructure. Requires pip install dagron[celery]. |

ThreadBackend

The simplest backend wraps Python's ThreadPoolExecutor. It is a good fit for I/O-bound workloads (API calls, database queries, file downloads):

```python
import dagron
from dagron.execution.distributed_executor import DistributedExecutor
from dagron.execution.backends.thread import ThreadBackend

dag = (
    dagron.DAG.builder()
    .add_node("fetch_users")
    .add_node("fetch_orders")
    .add_node("join")
    .add_edge("fetch_users", "join")
    .add_edge("fetch_orders", "join")
    .build()
)

backend = ThreadBackend(max_workers=8)

with DistributedExecutor(dag, backend) as executor:
    result = executor.execute({
        "fetch_users": lambda: fetch_from_api("/users"),
        "fetch_orders": lambda: fetch_from_api("/orders"),
        "join": lambda: merge_data(),
    })

print(f"Backend: {result.backend_name}")  # "thread"
print(f"Succeeded: {result.execution_result.succeeded}")
```

The with statement ensures backend.shutdown() is called when execution completes.


MultiprocessingBackend

Bypasses the GIL for CPU-bound workloads by dispatching tasks to separate processes:

```python
from dagron.execution.backends.multiprocessing import MultiprocessingBackend

backend = MultiprocessingBackend(max_workers=4)

with DistributedExecutor(dag, backend) as executor:
    result = executor.execute(tasks)
```
Caution: tasks must be picklable when using MultiprocessingBackend. Lambda functions and closures cannot be pickled; use module-level functions instead.

```python
# This works:
def compute_features():
    return heavy_computation()

tasks = {"features": compute_features}

# This does NOT work with multiprocessing:
tasks = {"features": lambda: heavy_computation()}
```

RayBackend

Distribute tasks across a Ray cluster for true multi-machine parallelism:

```python
from dagron.execution.backends.ray import RayBackend

# Initialize Ray (or connect to an existing cluster)
backend = RayBackend(num_cpus=16)

with DistributedExecutor(dag, backend, node_timeout=300) as executor:
    result = executor.execute(tasks)
```

Ray must be installed separately:

```shell
pip install dagron[ray]
```

If Ray is already initialized (e.g., you called ray.init() elsewhere), RayBackend detects this and reuses the existing session.

Ray Cluster Example

```python
import ray
from dagron.execution.backends.ray import RayBackend

# Connect to a remote cluster
ray.init(address="ray://cluster-head:10001")

backend = RayBackend()  # uses the existing Ray session

with DistributedExecutor(dag, backend) as executor:
    result = executor.execute({
        "train_model_a": lambda: train_on_gpu("model_a"),
        "train_model_b": lambda: train_on_gpu("model_b"),
        "ensemble": lambda: combine_models(),
    })
```

CeleryBackend

Integrate with existing Celery infrastructure for message-broker-based distribution:

```python
from celery import Celery
from dagron.execution.backends.celery import CeleryBackend

app = Celery("dagron_tasks", broker="redis://localhost:6379")

backend = CeleryBackend(app=app, queue="dagron")

with DistributedExecutor(dag, backend) as executor:
    result = executor.execute(tasks)
```

Celery must be installed separately:

```shell
pip install dagron[celery]
```

The queue parameter routes all dagron tasks to a specific Celery queue, keeping them separate from your other Celery tasks.


DistributedExecutor

The DistributedExecutor dispatches nodes by topological level. All nodes in a level are submitted to the backend concurrently, and results are collected before advancing to the next level.
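The level-by-level scheme is easy to picture with the standard library alone. In this sketch, topological_levels and dispatch_by_level are illustrative helpers (not dagron APIs), with graphlib computing the levels and a thread pool standing in for the backend:

```python
from concurrent.futures import ThreadPoolExecutor
from graphlib import TopologicalSorter

def topological_levels(edges, nodes):
    """Group nodes into levels of mutually independent nodes."""
    preds = {n: set() for n in nodes}
    for u, v in edges:
        preds[v].add(u)
    ts = TopologicalSorter(preds)
    ts.prepare()
    levels = []
    while ts.is_active():
        ready = list(ts.get_ready())  # every node whose parents have finished
        levels.append(ready)
        ts.done(*ready)
    return levels

def dispatch_by_level(edges, tasks):
    """Submit each level concurrently; collect results before the next level."""
    results = {}
    with ThreadPoolExecutor() as pool:
        for level in topological_levels(edges, list(tasks)):
            futures = {n: pool.submit(tasks[n]) for n in level}
            for n, f in futures.items():
                results[n] = f.result()
    return results

edges = [("fetch_users", "join"), ("fetch_orders", "join")]
tasks = {
    "fetch_users": lambda: "users",
    "fetch_orders": lambda: "orders",
    "join": lambda: "joined",
}
print(dispatch_by_level(edges, tasks) ==
      {"fetch_users": "users", "fetch_orders": "orders", "join": "joined"})  # True
```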

```python
from dagron.execution.distributed_executor import DistributedExecutor

executor = DistributedExecutor(
    dag,
    backend=backend,
    fail_fast=True,        # skip downstream on failure
    enable_tracing=True,   # record execution trace
    node_timeout=60.0,     # per-node timeout in seconds
)

result = executor.execute(tasks)
```

DistributedExecutionResult

The result contains the standard ExecutionResult plus distributed metadata:

```python
result = executor.execute(tasks)

# Standard execution stats
er = result.execution_result
print(f"Succeeded: {er.succeeded}, Failed: {er.failed}")
print(f"Total time: {er.total_duration_seconds:.1f}s")

# Distributed metadata
print(f"Backend: {result.backend_name}")
print(f"Dispatch info: {result.dispatch_info}")
# e.g. {"fetch_users": {"backend": "ray"}, ...}
```

Context Manager

DistributedExecutor supports context-manager usage for automatic cleanup:

```python
with DistributedExecutor(dag, backend) as executor:
    result = executor.execute(tasks)
# backend.shutdown(wait=True) is called automatically
```

Node Timeout

Set node_timeout to fail nodes that take too long:

```python
executor = DistributedExecutor(dag, backend, node_timeout=30.0)
result = executor.execute(tasks)

# Check for timed-out nodes
print(f"Timed out: {result.execution_result.timed_out}")
```

Timed-out nodes are treated as failures and trigger fail-fast behavior for downstream nodes.
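In terms of plain concurrent.futures (which ThreadBackend wraps), the timeout behaves like the timeout argument to future.result; a minimal stdlib-only illustration:

```python
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError

pool = ThreadPoolExecutor(max_workers=1)
future = pool.submit(time.sleep, 0.5)  # stands in for a runaway node

try:
    future.result(timeout=0.05)        # the node_timeout equivalent
    timed_out = False
except TimeoutError:
    timed_out = True                   # node is recorded as a timeout failure

print(timed_out)  # True
pool.shutdown(wait=True)
```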


PartitionedDAGExecutor

For large DAGs, dispatching every node individually to a remote backend can create excessive serialization overhead. The PartitionedDAGExecutor solves this by splitting the DAG into k partitions and executing each partition as a sub-DAG:

```python
from dagron.execution.distributed import PartitionedDAGExecutor

executor = PartitionedDAGExecutor(
    dag,
    k=4,                     # target number of partitions
    strategy="level_based",  # partitioning strategy
    max_workers=8,           # workers per partition
    fail_fast=True,
)

result = executor.execute(tasks)
```

Partitioning Strategies

| Strategy | Description | Best for |
| --- | --- | --- |
| "level_based" | Assigns nodes to partitions based on their topological level. | Balanced, predictable partitions. |
| "balanced" | Distributes nodes to minimize the maximum partition cost. | Cost-aware balancing when node costs vary widely. |
| "communication_min" | Minimizes cross-partition edges (Kernighan-Lin style). | Minimizing serialization overhead between partitions. |

Level-Based Partitioning

Groups nodes by topological level and distributes levels across k partitions:

```python
executor = PartitionedDAGExecutor(dag, k=3, strategy="level_based")
```

Level-based partitioning: each level maps to a partition.
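One plausible reading of the strategy (a sketch, not necessarily dagron's exact algorithm): compute each node's topological level as one plus the deepest parent's level, then send level i to partition i % k:

```python
def level_based_partition(edges, nodes, k):
    """Assign each node to a partition by topological level (level i -> i % k)."""
    parents = {n: [] for n in nodes}
    for u, v in edges:
        parents[v].append(u)

    level = {}
    def lvl(n):
        # A source node has level 0; otherwise 1 + the deepest parent.
        if n not in level:
            level[n] = 1 + max((lvl(p) for p in parents[n]), default=-1)
        return level[n]

    return {n: lvl(n) % k for n in nodes}

edges = [("extract", "transform"), ("transform", "load")]
print(level_based_partition(edges, ["extract", "transform", "load"], k=2))
# {'extract': 0, 'transform': 1, 'load': 0}
```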

Balanced Partitioning

When nodes have very different execution costs, use balanced partitioning:

```python
costs = {
    "extract": 5.0,
    "heavy_transform": 120.0,
    "light_transform": 2.0,
    "load": 10.0,
}

executor = PartitionedDAGExecutor(
    dag,
    k=2,
    strategy="balanced",
    costs=costs,
)
```
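Minimizing the maximum partition cost is the classic makespan-scheduling objective. Ignoring edges for a moment, a greedy longest-processing-time pass shows the idea (a sketch, not dagron's actual implementation, which also has to respect DAG structure):

```python
import heapq

def balanced_partition(costs, k):
    """Greedy LPT: give the most expensive remaining node to the lightest partition."""
    heap = [(0.0, part) for part in range(k)]  # (accumulated cost, partition id)
    heapq.heapify(heap)
    assignment = {}
    for node in sorted(costs, key=costs.get, reverse=True):
        load, part = heapq.heappop(heap)
        assignment[node] = part
        heapq.heappush(heap, (load + costs[node], part))
    return assignment

costs = {"extract": 5.0, "heavy_transform": 120.0, "light_transform": 2.0, "load": 10.0}
assignment = balanced_partition(costs, k=2)
# heavy_transform ends up alone; the three cheap nodes share the other partition
```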

Communication-Minimizing Partitioning

Minimizes the number of edges that cross partition boundaries:

```python
executor = PartitionedDAGExecutor(
    dag,
    k=3,
    strategy="communication_min",
    max_iterations=20,   # Kernighan-Lin iterations
    max_imbalance=0.3,   # allow 30% size imbalance
)
```

The max_imbalance parameter controls the trade-off between partition balance and communication minimization. A value of 0.0 requires perfectly balanced partitions; 0.3 allows 30% deviation.
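Both quantities are easy to compute for a candidate assignment. These helpers (edge_cut and imbalance are illustrative names, not dagron APIs, and imbalance uses one plausible definition) show what the strategy is trading off:

```python
from collections import Counter

def edge_cut(edges, assignment):
    """Number of edges whose endpoints fall in different partitions."""
    return sum(1 for u, v in edges if assignment[u] != assignment[v])

def imbalance(assignment, k):
    """Worst relative deviation from the ideal partition size."""
    sizes = Counter(assignment.values())
    ideal = len(assignment) / k
    return max(abs(sizes.get(p, 0) - ideal) / ideal for p in range(k))

assignment = {"a": 0, "b": 0, "c": 1, "d": 1}
edges = [("a", "b"), ("b", "c"), ("a", "c"), ("c", "d")]
print(edge_cut(edges, assignment))  # 2: b->c and a->c cross the boundary
print(imbalance(assignment, k=2))   # 0.0: perfectly balanced
```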


Choosing Between Executors

| Scenario | Recommended Executor |
| --- | --- |
| Small DAG, I/O-bound tasks | DistributedExecutor + ThreadBackend |
| Small DAG, CPU-bound tasks | DistributedExecutor + MultiprocessingBackend |
| Large DAG, multi-machine cluster | PartitionedDAGExecutor with "communication_min" |
| Existing Celery infrastructure | DistributedExecutor + CeleryBackend |
| GPU cluster | DistributedExecutor + RayBackend |

Writing a Custom Backend

Implement the DistributedBackend protocol to integrate with any execution system:

```python
from dagron.execution.backends.base import DistributedBackend

class DaskBackend:
    """Example backend using Dask distributed."""

    def __init__(self, scheduler_address: str):
        from dask.distributed import Client
        self._client = Client(scheduler_address)

    @property
    def name(self) -> str:
        return "dask"

    def submit(self, fn, *args, **kwargs):
        return self._client.submit(fn, *args, **kwargs)

    def result(self, future, timeout=None):
        return future.result(timeout=timeout)

    def shutdown(self, wait=True):
        self._client.close()

# Usage
backend = DaskBackend("tcp://scheduler:8786")
executor = DistributedExecutor(dag, backend)
```

Combining with Other Features

Distributed + Tracing

Enable tracing to see per-node timing across distributed workers:

```python
executor = DistributedExecutor(dag, backend, enable_tracing=True)
result = executor.execute(tasks)

trace = result.execution_result.trace
if trace:
    trace.to_chrome_json("distributed_trace.json")
```

Distributed + Fail-Fast

```python
executor = DistributedExecutor(dag, backend, fail_fast=True)
```

When a node fails, all downstream nodes are skipped, even across different topological levels.
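"Downstream" means every transitive descendant of the failed node, not just its direct children. A reachability sketch (the downstream helper is illustrative, not a dagron API):

```python
from collections import defaultdict, deque

def downstream(edges, failed):
    """Every node reachable from `failed` -- these are skipped under fail-fast."""
    children = defaultdict(list)
    for u, v in edges:
        children[u].append(v)
    skipped, queue = set(), deque([failed])
    while queue:
        node = queue.popleft()
        for child in children[node]:
            if child not in skipped:
                skipped.add(child)
                queue.append(child)
    return skipped

edges = [("preprocess", "train"), ("train", "ensemble"), ("preprocess", "validate")]
print(sorted(downstream(edges, "preprocess")))  # ['ensemble', 'train', 'validate']
```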

Partitioned + Cost Estimates

Provide cost estimates for better partitioning:

```python
costs = {node: estimate_cost(node) for node in dag.node_names()}

executor = PartitionedDAGExecutor(
    dag,
    k=4,
    strategy="balanced",
    costs=costs,
)
```

Complete Example: Ray Cluster Training

```python
import dagron
from dagron.execution.distributed_executor import DistributedExecutor
from dagron.execution.backends.ray import RayBackend

# Build a training pipeline
dag = (
    dagron.DAG.builder()
    .add_node("load_data")
    .add_node("preprocess")
    .add_node("train_model_a")
    .add_node("train_model_b")
    .add_node("train_model_c")
    .add_node("ensemble")
    .add_node("evaluate")
    .add_edge("load_data", "preprocess")
    .add_edge("preprocess", "train_model_a")
    .add_edge("preprocess", "train_model_b")
    .add_edge("preprocess", "train_model_c")
    .add_edge("train_model_a", "ensemble")
    .add_edge("train_model_b", "ensemble")
    .add_edge("train_model_c", "ensemble")
    .add_edge("ensemble", "evaluate")
    .build()
)

def load_data():
    return load_dataset("imagenet")

def preprocess():
    return normalize_images()

def train_model_a():
    return train("resnet50", epochs=10)

def train_model_b():
    return train("vgg16", epochs=10)

def train_model_c():
    return train("efficientnet", epochs=10)

def ensemble():
    return combine_predictions()

def evaluate():
    return compute_metrics()

tasks = {
    "load_data": load_data,
    "preprocess": preprocess,
    "train_model_a": train_model_a,
    "train_model_b": train_model_b,
    "train_model_c": train_model_c,
    "ensemble": ensemble,
    "evaluate": evaluate,
}

# Dispatch to Ray -- models train in parallel on different machines
backend = RayBackend(num_cpus=32)

with DistributedExecutor(dag, backend, node_timeout=3600) as executor:
    result = executor.execute(tasks)

er = result.execution_result
print(f"Succeeded: {er.succeeded}/{er.succeeded + er.failed}")
print(f"Total time: {er.total_duration_seconds:.0f}s")
```

Best Practices

  1. Start with ThreadBackend for development. Switch to RayBackend or CeleryBackend for production.

  2. Use PartitionedDAGExecutor for large DAGs. When your DAG has hundreds of nodes, per-node dispatch overhead adds up. Partitioning reduces it.

  3. Provide cost estimates. The balanced and communication_min strategies produce much better partitions when they know how long each node takes.

  4. Set node_timeout. Prevent runaway tasks from blocking the entire pipeline.

  5. Use the context manager. Always use with DistributedExecutor(...) as executor: to ensure proper cleanup.

  6. Avoid lambdas with multiprocessing. Module-level functions are required for pickling.