Distributed Execution

The distributed execution module lets you run DAG tasks across multiple backends: threads, processes, Ray clusters, or Celery workers. A pluggable backend protocol makes it easy to integrate with any distributed computing framework.

For large DAGs, the PartitionedDAGExecutor splits the graph into partitions and assigns each partition to a different worker group for improved data locality and reduced communication overhead.

See the Distributed Execution guide for deployment patterns and backend selection advice.


DistributedExecutor

DistributedExecutor
class DistributedExecutor(
    dag: DAG,
    backend: DistributedBackend,
    callbacks: ExecutionCallbacks | None = None,
    fail_fast: bool = True,
    enable_tracing: bool = False,
    node_timeout: float | None = None,
)

An executor that dispatches tasks to a pluggable distributed backend. Supports the context manager protocol for automatic backend shutdown.

| Parameter | Type | Default | Description |
|---|---|---|---|
| dag | DAG | required | The DAG whose topology drives execution order. |
| backend | DistributedBackend | required | The backend implementation for dispatching tasks. |
| callbacks | ExecutionCallbacks \| None | None | Optional lifecycle callbacks. |
| fail_fast | bool | True | If True, skip downstream nodes when any node fails. |
| enable_tracing | bool | False | If True, record a Chrome-compatible execution trace. |
| node_timeout | float \| None | None | Per-node timeout in seconds. Nodes exceeding this timeout are marked TIMED_OUT. |
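The `node_timeout` behavior can be pictured with just the standard library: wait on each node's future with a deadline, and record a TIMED_OUT status when the wait expires. This is an illustrative sketch of the mechanism, not dagron's internal implementation.

```python
# Sketch of per-node timeout semantics using only concurrent.futures.
import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

def slow_node():
    time.sleep(1.0)  # stands in for a task that overruns its budget
    return "done"

with ThreadPoolExecutor(max_workers=1) as pool:
    future = pool.submit(slow_node)
    try:
        future.result(timeout=0.1)  # analogous to node_timeout=0.1
        status = "SUCCEEDED"
    except FutureTimeout:
        status = "TIMED_OUT"

print(status)  # TIMED_OUT
```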

execute

DistributedExecutor.execute
def execute(
    tasks: dict[str, Callable],
) -> DistributedExecutionResult

Execute tasks via the distributed backend.

| Parameter | Type | Default | Description |
|---|---|---|---|
| tasks | dict[str, Callable] | required | Map of node names to callable tasks. |

Returns: DistributedExecutionResult

Context Manager

The executor can be used as a context manager for automatic backend shutdown:

import dagron

dag = (
    dagron.DAG.builder()
    .add_node("fetch").add_node("process").add_node("store")
    .add_edge("fetch", "process").add_edge("process", "store")
    .build()
)

with dagron.DistributedExecutor(dag, backend=dagron.ThreadBackend(max_workers=4)) as executor:
    result = executor.execute({
        "fetch": lambda: "data",
        "process": lambda: "processed",
        "store": lambda: "stored",
    })

print(result.execution_result.succeeded)  # 3
# Backend is automatically shut down on exit

DistributedExecutionResult

DistributedExecutionResult
class DistributedExecutionResult(
    execution_result: ExecutionResult,
    backend_name: str,
    dispatch_info: dict[str, Any],
)

The result of a distributed execution. Wraps the standard ExecutionResult with backend-specific metadata.

| Property | Type | Description |
|---|---|---|
| execution_result | ExecutionResult | The underlying execution result with per-node details. |
| backend_name | str | Name of the backend used (e.g., "thread", "ray", "celery"). |
| dispatch_info | dict[str, Any] | Backend-specific dispatch metadata (worker IDs, queue names, etc.). |

print(f"Backend: {result.backend_name}")
print(f"Succeeded: {result.execution_result.succeeded}")
print(f"Dispatch info: {result.dispatch_info}")

DistributedBackend Protocol

DistributedBackend
class DistributedBackend(Protocol):
    @property
    def name(self) -> str: ...

    def submit(
        self,
        task: Callable,
        node_name: str,
    ) -> Any: ...

    def result(
        self,
        future: Any,
        timeout: float | None = None,
    ) -> Any: ...

    def shutdown(self) -> None: ...

The protocol that all distributed backends must implement. You can create custom backends by implementing these four members.

| Method | Description |
|---|---|
| name | A human-readable backend name. |
| submit(task, node_name) | Submit a task for execution. Returns a future-like object. |
| result(future, timeout) | Block until the future completes and return its result. Raises on timeout. |
| shutdown() | Shut down the backend and release all resources. |
class MyCustomBackend:
    @property
    def name(self) -> str:
        return "custom"

    def submit(self, task, node_name):
        # dispatch to your infrastructure
        return my_cluster.submit(task)

    def result(self, future, timeout=None):
        return future.get(timeout=timeout)

    def shutdown(self):
        my_cluster.close()
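For a runnable companion to the sketch above (where `my_cluster` is a placeholder), the same four-member shape can be satisfied with the standard library's `ThreadPoolExecutor`. `LocalThreadBackend` is a name invented for this illustration; it is not dagron's built-in `ThreadBackend`.

```python
# A minimal, self-contained class matching the DistributedBackend protocol shape.
from concurrent.futures import ThreadPoolExecutor

class LocalThreadBackend:
    def __init__(self, max_workers=4):
        self._pool = ThreadPoolExecutor(max_workers=max_workers)

    @property
    def name(self) -> str:
        return "local-thread"

    def submit(self, task, node_name):
        # The future returned here is what the executor later passes to result().
        return self._pool.submit(task)

    def result(self, future, timeout=None):
        # Blocks until done; raises concurrent.futures.TimeoutError on timeout.
        return future.result(timeout=timeout)

    def shutdown(self) -> None:
        self._pool.shutdown(wait=True)

backend = LocalThreadBackend(max_workers=2)
fut = backend.submit(lambda: 2 + 2, "demo_node")
print(backend.result(fut))  # 4
backend.shutdown()
```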

Built-in Backends

ThreadBackend

ThreadBackend
class ThreadBackend(
    max_workers: int | None = None,
)

A backend that dispatches tasks to a concurrent.futures.ThreadPoolExecutor. Best for I/O-bound tasks.

| Parameter | Type | Default | Description |
|---|---|---|---|
| max_workers | int \| None | None | Maximum number of worker threads. None uses os.cpu_count(). |

backend = dagron.ThreadBackend(max_workers=8)
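The "best for I/O-bound tasks" advice comes down to threads overlapping their waits. A standard-library sketch, independent of dagron:

```python
# Four 0.2s "I/O" tasks on four threads finish in roughly 0.2s, not 0.8s,
# because blocked threads wait concurrently.
import time
from concurrent.futures import ThreadPoolExecutor

def fake_io():
    time.sleep(0.2)  # stands in for a network or disk wait
    return "ok"

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    results = [f.result() for f in [pool.submit(fake_io) for _ in range(4)]]
elapsed = time.perf_counter() - start

print(results)        # ['ok', 'ok', 'ok', 'ok']
print(elapsed < 0.6)  # True: the waits overlapped rather than summed
```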

MultiprocessingBackend

MultiprocessingBackend
class MultiprocessingBackend(
    max_workers: int | None = None,
)

A backend that dispatches tasks to a concurrent.futures.ProcessPoolExecutor. Best for CPU-bound tasks. Tasks must be picklable.

| Parameter | Type | Default | Description |
|---|---|---|---|
| max_workers | int \| None | None | Maximum number of worker processes. None uses os.cpu_count(). |

backend = dagron.MultiprocessingBackend(max_workers=4)

caution

Tasks submitted to MultiprocessingBackend must be picklable. Lambdas and closures will fail. Use module-level functions instead.
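The restriction can be checked directly with the standard pickle module, since a process pool serializes each task before sending it to a worker. `square` and `is_picklable` below are helper names invented for this illustration.

```python
# Why lambdas fail with a process pool: submitted tasks are pickled and sent
# to worker processes, and lambdas cannot be pickled.
import pickle

def square(x):
    # Module-level functions like this pickle by reference, so they are
    # generally safe to submit to a process-based backend.
    return x * x

def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False

print(is_picklable(lambda x: x * x))  # False
```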

RayBackend

RayBackend
class RayBackend(
    address: str | None = None,
    num_cpus: int | None = None,
    num_gpus: int | None = None,
    runtime_env: dict | None = None,
)

A backend that dispatches tasks to a Ray cluster. Requires ray to be installed (pip install dagron[ray]).

| Parameter | Type | Default | Description |
|---|---|---|---|
| address | str \| None | None | Ray cluster address. None starts a local cluster. |
| num_cpus | int \| None | None | Number of CPUs to request per task. |
| num_gpus | int \| None | None | Number of GPUs to request per task. |
| runtime_env | dict \| None | None | Ray runtime environment (pip packages, env vars, etc.). |

backend = dagron.RayBackend(
    address="ray://cluster:10001",
    num_cpus=2,
    num_gpus=1,
)

with dagron.DistributedExecutor(dag, backend=backend) as executor:
    result = executor.execute(tasks)

CeleryBackend

CeleryBackend
class CeleryBackend(
    app: Any = None,
    broker: str | None = None,
    backend_url: str | None = None,
    queue: str = "default",
)

A backend that dispatches tasks to Celery workers. Requires celery to be installed (pip install dagron[celery]).

| Parameter | Type | Default | Description |
|---|---|---|---|
| app | Any | None | An existing Celery app instance. If None, a new one is created. |
| broker | str \| None | None | Celery broker URL (e.g., 'redis://localhost:6379/0'). |
| backend_url | str \| None | None | Celery result backend URL. |
| queue | str | "default" | The Celery queue name for task dispatch. |

backend = dagron.CeleryBackend(
    broker="redis://localhost:6379/0",
    backend_url="redis://localhost:6379/1",
    queue="dagron_tasks",
)

with dagron.DistributedExecutor(dag, backend=backend) as executor:
    result = executor.execute(tasks)

PartitionedDAGExecutor

PartitionedDAGExecutor
class PartitionedDAGExecutor(
    dag: DAG,
    k: int,
    strategy: str = "balanced",
    costs: dict[str, float] | None = None,
    max_workers: int | None = None,
    callbacks: ExecutionCallbacks | None = None,
    fail_fast: bool = True,
    enable_tracing: bool = False,
)

An executor that partitions the DAG into k groups and executes each partition with a dedicated worker pool. This reduces inter-partition communication and improves data locality for large DAGs.

| Parameter | Type | Default | Description |
|---|---|---|---|
| dag | DAG | required | The DAG to partition and execute. |
| k | int | required | Number of partitions. |
| strategy | str | "balanced" | Partitioning strategy: "level_based", "balanced", or "communication_min". |
| costs | dict[str, float] \| None | None | Per-node cost estimates for partitioning heuristics. |
| max_workers | int \| None | None | Total number of workers across all partitions. |
| callbacks | ExecutionCallbacks \| None | None | Optional lifecycle callbacks. |
| fail_fast | bool | True | If True, skip downstream nodes when any node fails. |
| enable_tracing | bool | False | If True, record a Chrome-compatible execution trace. |

execute

PartitionedDAGExecutor.execute
def execute(
    tasks: dict[str, Callable],
) -> ExecutionResult

Partition the DAG and execute tasks.

| Parameter | Type | Default | Description |
|---|---|---|---|
| tasks | dict[str, Callable] | required | Map of node names to callable tasks. |

Returns: ExecutionResult

import dagron

# A large DAG with many nodes
dag = dagron.DAG.builder()
for i in range(100):
    dag = dag.add_node(f"node_{i}")
for i in range(99):
    dag = dag.add_edge(f"node_{i}", f"node_{i+1}")
dag = dag.build()

tasks = {f"node_{i}": lambda i=i: f"result_{i}" for i in range(100)}

executor = dagron.PartitionedDAGExecutor(
    dag,
    k=4,
    strategy="balanced",
    max_workers=8,
)

result = executor.execute(tasks)
print(f"Succeeded: {result.succeeded}") # 100

Strategies

| Strategy | Description |
|---|---|
| "level_based" | Assign nodes to partitions based on their topological level. Simple and fast. |
| "balanced" | Balance node costs across partitions. Good general-purpose strategy. |
| "communication_min" | Minimize cross-partition edges using Kernighan-Lin refinement. Best for data-intensive pipelines. |

These map to DAG.partition_level_based(), DAG.partition_balanced(), and DAG.partition_communication_min() respectively. See DAG partitioning for the underlying algorithms.
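As a rough illustration of the "level_based" idea (not dagron's actual algorithm), topological levels can be computed and mapped to partitions as below; `topological_levels` and `partition_level_based` are names invented for this sketch.

```python
# Level-based partitioning sketch: nodes at the same topological depth land in
# the same partition group, and levels are spread across k partitions.
from collections import defaultdict

def topological_levels(nodes, edges):
    # level(n) = 0 for roots, else 1 + max(level of predecessors)
    preds = defaultdict(list)
    for src, dst in edges:
        preds[dst].append(src)
    level = {}
    def resolve(n):
        if n not in level:
            level[n] = 0 if not preds[n] else 1 + max(resolve(p) for p in preds[n])
        return level[n]
    for n in nodes:
        resolve(n)
    return level

def partition_level_based(nodes, edges, k):
    level = topological_levels(nodes, edges)
    # Assign whole levels to partitions round-robin.
    return {n: level[n] % k for n in nodes}

nodes = ["a", "b", "c", "d"]
edges = [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]
print(partition_level_based(nodes, edges, 2))
# a (level 0) -> partition 0; b, c (level 1) -> partition 1; d (level 2) -> partition 0
```

Keeping each level together means every cross-partition edge points "forward", which is what makes the strategy simple and fast.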


Complete Example: Ray Cluster

A complete distributed ML training pipeline running on Ray:

import dagron

dag = (
    dagron.DAG.builder()
    .add_node("load_data")
    .add_node("preprocess")
    .add_node("train_xgb")
    .add_node("train_nn")
    .add_node("ensemble")
    .add_node("evaluate")
    .add_edge("load_data", "preprocess")
    .add_edge("preprocess", "train_xgb")
    .add_edge("preprocess", "train_nn")
    .add_edge("train_xgb", "ensemble")
    .add_edge("train_nn", "ensemble")
    .add_edge("ensemble", "evaluate")
    .build()
)

def load_data():
    return "loaded 1M rows"

def preprocess():
    return "preprocessed features"

def train_xgb():
    import time; time.sleep(5)
    return {"model": "xgb", "auc": 0.92}

def train_nn():
    import time; time.sleep(10)
    return {"model": "nn", "auc": 0.94}

def ensemble():
    return {"model": "ensemble", "auc": 0.96}

def evaluate():
    return "evaluation report saved"

tasks = {
    "load_data": load_data,
    "preprocess": preprocess,
    "train_xgb": train_xgb,
    "train_nn": train_nn,
    "ensemble": ensemble,
    "evaluate": evaluate,
}

backend = dagron.RayBackend(num_cpus=2, num_gpus=1)

with dagron.DistributedExecutor(
    dag,
    backend=backend,
    enable_tracing=True,
    node_timeout=300,
    callbacks=dagron.ExecutionCallbacks(
        on_start=lambda n: print(f"[{backend.name}] Starting {n}"),
        on_complete=lambda n, r: print(f"[{backend.name}] Completed {n}"),
    ),
) as executor:
    result = executor.execute(tasks)

print(f"\nBackend: {result.backend_name}")
print(f"Succeeded: {result.execution_result.succeeded}")
print(f"Duration: {result.execution_result.total_duration_seconds:.1f}s")