Skip to main content

Contracts

The contracts module provides build-time type checking for DAG edges. You can declare the expected input and output types for each node, and the validator checks that producer output types are compatible with consumer input types across every edge. This catches type mismatches before execution.

For pipelines built with the @task decorator, contracts can be automatically extracted from type annotations.

from dagron.contracts import (
NodeContract,
ContractValidator,
ContractViolation,
extract_contracts,
validate_contracts,
)

NodeContract

NodeContract
@dataclass(frozen=True)
class NodeContract:
inputs: dict[str, type] = field(default_factory=dict)
output: type = object

Type contract for a single node's inputs and outputs. This is a frozen dataclass, so instances are hashable and immutable after creation.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| inputs | dict[str, type] | {} | Mapping of dependency name to expected input type. The key is the name of the upstream node, and the value is the type that this node expects to receive from that upstream node. |
| output | type | object | The declared output type of this node. Using object (the default) acts as a wildcard, equivalent to Any. |
from dagron.contracts import NodeContract

# A node that takes a list from 'extract' and produces a dict
transform_contract = NodeContract(
inputs={"extract": list},
output=dict,
)

# A node with no type constraints (wildcard)
passthrough = NodeContract()

ContractViolation

ContractViolation
@dataclass(frozen=True)
class ContractViolation:
from_node: str
to_node: str
message: str

A single type-contract violation detected during validation. Frozen dataclass, so instances are immutable and hashable.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| from_node | str | required | The upstream (producer) node name. |
| to_node | str | required | The downstream (consumer) node name. |
| message | str | required | Human-readable description of the type mismatch. |
for violation in violations:
print(f"Edge {violation.from_node} -> {violation.to_node}: {violation.message}")

ContractValidator

ContractValidator
class ContractValidator:
def __init__(
self,
dag: DAG,
contracts: dict[str, NodeContract],
) -> None: ...

Validates type contracts across DAG edges. For every edge (u, v) in the DAG, the validator checks that the output type of u is compatible with the expected input type declared by v for dependency u. Compatibility is determined via issubclass. The object type acts as a wildcard.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| dag | DAG | required | The DAG to validate. |
| contracts | dict[str, NodeContract] | required | Mapping of node names to their type contracts. |

Methods


ContractValidator.validate

ContractValidator.validate
def validate(self) -> list[ContractViolation]

Run validation and return all detected violations. An empty list means all contracts are satisfied.

Returns: list[ContractViolation] -- List of type mismatches found across DAG edges.

import dagron
from dagron.contracts import NodeContract, ContractValidator

dag = (
dagron.DAG.builder()
.add_edge("extract", "transform")
.add_edge("transform", "load")
.build()
)

contracts = {
"extract": NodeContract(output=list),
"transform": NodeContract(inputs={"extract": dict}, output=str),
"load": NodeContract(inputs={"transform": str}),
}

validator = ContractValidator(dag, contracts)
violations = validator.validate()

for v in violations:
print(v.message)
# Type mismatch on edge extract -> transform: producer outputs list,
# but consumer expects dict

Compatibility rules

The validator uses issubclass to check compatibility:

  • list is compatible with list (exact match).
  • bool is compatible with int (subclass relationship).
  • object is always compatible (wildcard / Any equivalent).
  • Generic type aliases (e.g., list[int]) cannot be checked with issubclass, which raises TypeError for them; when that happens, the validator treats the types as compatible.

extract_contracts

extract_contracts
def extract_contracts(
pipeline: Pipeline,
) -> dict[str, NodeContract]

Auto-extract NodeContract instances from a Pipeline's @task functions. Uses typing.get_type_hints() to read input parameter types and return annotations from each decorated function.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| pipeline | Pipeline | required | A Pipeline instance whose tasks have type annotations. |

Returns: dict[str, NodeContract] -- Mapping of task names to their extracted contracts.

from dagron import Pipeline, task
from dagron.contracts import extract_contracts

@task
def extract() -> list:
return [1, 2, 3]

@task
def transform(extract: list) -> dict:
return {"data": extract}

@task
def load(transform: dict) -> str:
return "done"

pipeline = Pipeline(tasks=[extract, transform, load])
contracts = extract_contracts(pipeline)

print(contracts["extract"].output) # <class 'list'>
print(contracts["transform"].inputs) # {'extract': <class 'list'>}
print(contracts["transform"].output) # <class 'dict'>

validate_contracts

validate_contracts
def validate_contracts(
pipeline: Pipeline,
extra_contracts: dict[str, NodeContract] | None = None,
) -> list[ContractViolation]

Convenience function that extracts contracts from a pipeline and validates them in a single call. Optionally merges manually specified contracts that override the auto-extracted ones.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| pipeline | Pipeline | required | A Pipeline instance to validate. |
| extra_contracts | dict[str, NodeContract] \| None | None | Optional manually specified contracts that override auto-extracted ones for specific nodes. |

Returns: list[ContractViolation] -- List of violations. Empty means all contracts are satisfied.

from dagron import Pipeline, task
from dagron.contracts import validate_contracts, NodeContract

@task
def extract() -> list:
return [1, 2, 3]

@task
def transform(extract: dict) -> str: # Bug: expects dict, but extract returns list
return str(extract)

pipeline = Pipeline(tasks=[extract, transform])
violations = validate_contracts(pipeline)

if violations:
for v in violations:
print(f"Contract violation: {v.message}")
# Contract violation: Type mismatch on edge extract -> transform:
# producer outputs list, but consumer expects dict

Overriding extracted contracts

Sometimes auto-extraction is not enough -- for example, when functions lack type annotations or when you want stricter constraints:

from dagron.contracts import validate_contracts, NodeContract

overrides = {
"transform": NodeContract(
inputs={"extract": list},
output=dict,
),
}

violations = validate_contracts(pipeline, extra_contracts=overrides)

Complete example

import dagron
from dagron import Pipeline, task
from dagron.contracts import (
NodeContract,
ContractValidator,
extract_contracts,
validate_contracts,
)

# Define a typed pipeline
@task
def fetch_users() -> list:
return [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]

@task
def normalize(fetch_users: list) -> list:
return [{"id": u["id"], "name": u["name"].upper()} for u in fetch_users]

@task
def store(normalize: list) -> int:
return len(normalize)

pipeline = Pipeline(tasks=[fetch_users, normalize, store], name="users")

# Validate contracts automatically
violations = validate_contracts(pipeline)
assert not violations, f"Contract violations: {violations}"

# Or extract and inspect contracts manually
contracts = extract_contracts(pipeline)
for name, contract in contracts.items():
print(f"{name}: inputs={contract.inputs}, output={contract.output}")

# Manual validation against an arbitrary DAG
dag = (
dagron.DAG.builder()
.add_edge("source", "sink")
.build()
)

manual_contracts = {
"source": NodeContract(output=str),
"sink": NodeContract(inputs={"source": int}), # Mismatch!
}

validator = ContractValidator(dag, manual_contracts)
for v in validator.validate():
print(v.message)
# Type mismatch on edge source -> sink: producer outputs str,
# but consumer expects int

See also