5 changes: 5 additions & 0 deletions src/scheduling/layer_allocation.py
@@ -203,6 +203,11 @@ def deallocate(self, node: Node) -> None:
node.is_active = False
self._update_layer_loads_heap()

def reallocate(self, node: Node, start_layer: int, end_layer: int) -> None:
"""Reallocate a node to a specific layer range."""
self.deallocate(node)
self.allocate(node, start_layer, end_layer)

def declare(self, node: Node) -> None:
"""Declare a node to the allocator."""
if node.node_id not in self.node_id_to_node:
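
The new reallocate helper simply releases a node's current range and allocates the new one, keeping the allocator's internal bookkeeping (node allocation map and per-layer load heap) consistent, as the comments later in this PR point out. A minimal usage sketch follows; the allocator and node variables are illustrative, not taken from this PR:

# Illustrative only (allocator and node are assumed to already exist):
# shrink a node's shard from [4, 20) to [4, 16) in one call rather than
# pairing deallocate() and allocate() manually.
allocator.reallocate(node, start_layer=4, end_layer=16)
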
8 changes: 7 additions & 1 deletion src/scheduling/request_routing.py
@@ -52,11 +52,17 @@ class DynamicProgrammingRouting(RequestRoutingStrategy):
minimum-latency node sequence and total latency.
"""

def find_turning_points(self, nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
@staticmethod
def find_turning_points(nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
"""Find shard truncation points via layer-level DP.

DP state is (layer l, node i that hosts l). Node cost uses the node's
per-layer latency proxy; edge cost uses RTT between nodes.

This is a static method that can be called directly without creating an instance:
DynamicProgrammingRouting.find_turning_points(nodes, num_layers)

It can also be called on an instance, since Python resolves the static method through the class.
"""
if num_layers <= 0 or not nodes:
return []
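
Since the docstring compresses the DP into two sentences, here is a self-contained sketch of the same recurrence under simplified assumptions: nodes are plain (node_id, start_layer, end_layer, per_layer_cost) tuples and RTTs come from a dict keyed by ordered node-id pairs. None of these names are part of the PR's API; the real implementation works on Node objects inside DynamicProgrammingRouting.

from typing import Dict, List, Tuple

NodeTup = Tuple[str, int, int, float]  # (node_id, start_layer, end_layer, per_layer_cost)

def dp_turning_points(nodes: List[NodeTup],
                      rtt: Dict[Tuple[str, str], float],
                      num_layers: int) -> List[Tuple[str, int, str]]:
    """Layer-level DP: state is (layer l, node i hosting l); node cost is the
    per-layer latency proxy, edge cost the RTT between consecutive nodes."""
    if num_layers <= 0 or not nodes:
        return []
    INF = float("inf")
    best = [INF] * len(nodes)                       # best cost of layers [0, l] ending on node i
    prev = [[-1] * len(nodes) for _ in range(num_layers)]
    for l in range(num_layers):
        cur = [INF] * len(nodes)
        for i, (nid, s, e, cost) in enumerate(nodes):
            if not (s <= l < e):
                continue                            # node i does not host layer l
            if l == 0:
                cur[i] = cost
                continue
            for j, (mid, _, _, _) in enumerate(nodes):
                if best[j] == INF:
                    continue
                edge = 0.0 if i == j else rtt[(mid, nid)]   # rtt assumed to cover all pairs
                if best[j] + edge + cost < cur[i]:
                    cur[i], prev[l][i] = best[j] + edge + cost, j
        best = cur

    i = min(range(len(nodes)), key=lambda k: best[k])
    if best[i] == INF:                              # no full pipeline possible
        return []
    path = [i]                                      # backtrack the hosting node per layer
    for l in range(num_layers - 1, 0, -1):
        i = prev[l][i]
        path.append(i)
    path.reverse()

    # Turning points: layers where the optimal path switches nodes mid-shard.
    # "tail": the node we leave still owned layers >= l; "head": the node we
    # enter owned layers < l. These map onto the scheduler's truncation rules.
    turns: List[Tuple[str, int, str]] = []
    for l in range(1, num_layers):
        a, b = path[l - 1], path[l]
        if a == b:
            continue
        if l < nodes[a][2]:
            turns.append((nodes[a][0], l, "tail"))
        if l > nodes[b][1]:
            turns.append((nodes[b][0], l, "head"))
    return turns

The scheduler's _run_warmup_and_truncate then aggregates these turning points over several warm-up passes before truncating shards.
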
95 changes: 70 additions & 25 deletions src/scheduling/scheduler.py
@@ -116,72 +116,116 @@ def __init__(
pass

# Orchestration helpers
def bootstrap(self) -> bool:
"""Bootstrapping: first-time layer allocation and optional warm-up.
def bootstrap(self, *, clear_existing: bool = False, skip_warmup: bool = False) -> bool:
"""Bootstrapping:
This method can be used for both initial bootstrapping and global rebalancing.
When clear_existing=True, it first deallocates all existing allocations before
performing global allocation (rebalancing behavior). When clear_existing=False,
it performs allocation on top of existing state (initial bootstrapping behavior).

Returns True if a full pipeline was established; False otherwise.
Args:
clear_existing: If True, deallocate all existing allocations before reallocating.
This is used for global rebalancing. Default is False.
skip_warmup: If True, skip the warm-up and truncate step. Default is False.

Returns:
True if a full pipeline was established; False otherwise.
"""
if len(self.nodes) < self.min_nodes_bootstrapping:
# Check node count only for initial bootstrapping (not rebalancing)
if not clear_existing and len(self.nodes) < self.min_nodes_bootstrapping:
logger.debug(
f"Bootstrapping deferred: have {len(self.nodes)} nodes; need >= {self.min_nodes_bootstrapping}"
)
return False
logger.debug("Bootstrapping layer allocator")

# Clear existing allocations if this is a rebalance
if clear_existing:
logger.debug("Performing global rebalance (clearing existing allocations)")
self._bootstrapped = False
self._bootstrapped_event.clear()
for n in self.nodes:
if n.start_layer is not None and n.end_layer is not None:
self.layer_allocator.deallocate(n)
else:
logger.debug("Bootstrapping layer allocator")

# Perform global allocation
success = self.layer_allocator.global_allocation()
if not success:
logger.warning("Bootstrapping failed to produce a full pipeline")
logger.warning("Global allocation failed to produce a full pipeline")
return False

assignments = self.list_node_allocations()
logger.debug(f"Layer allocator assignments: {assignments}")

# Optional warm-up to find turning points and truncate node ranges
if self.request_warm_up_for_reshard > 0:
# Optional warm-up to find turning points and truncate node ranges; rebalance callers pass skip_warmup=True to skip it
if not skip_warmup and self.request_warm_up_for_reshard > 0:
self._run_warmup_and_truncate()
assignments = self.list_node_allocations()
logger.debug(f"Layer allocator assignments after turn-point warm-up: {assignments}")

if not self.layer_allocator.has_full_pipeline():
logger.warning("Bootstrapping failed to produce a full pipeline")
return False

self._bootstrapped = True
self._bootstrapped_event.set()
logger.debug("Bootstrapping completed successfully; full pipeline established")
action = "rebalance" if clear_existing else "bootstrapping"
logger.debug(f"{action.capitalize()} completed successfully; full pipeline established")
return True
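
# Usage sketch (illustrative, not part of this file): the two calling modes
# described in the docstring above, assuming `sched` is a constructed Scheduler.
#   sched.bootstrap()                                        # initial bootstrap; warm-up runs
#                                                            # if request_warm_up_for_reshard > 0
#   sched.bootstrap(clear_existing=True, skip_warmup=True)   # global rebalance; clears all
#                                                            # allocations and skips warm-up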

def list_node_allocations(self) -> List[Tuple[str, int, int]]:
"""List the allocations of all nodes."""
return self.layer_allocator.list_node_allocations()

# Warm-up and re-shard
def _run_warmup_and_truncate(self) -> None:
def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None:
"""Run a brief warm-up to detect truncation points and shrink shards.

Uses layer-level DP turning points (node_id, layer_idx, kind):
- kind == "tail": drop [layer_idx, end) on that node
- kind == "head": drop [start, layer_idx) on that node

Note: Always uses DynamicProgrammingRouting for finding turning points,
regardless of the current request_router type, since turning-point
detection requires layer-level DP analysis.

Args:
override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard.
Default is 0, which means use request_warm_up_for_reshard.
"""
nodes_list = list(self.nodes)
if not nodes_list:
return
num_layers = self.model_info.num_layers

# The number of warm-up requests can be used to repeat detection, but a
# single pass is sufficient with our DP model; we repeat to smooth noise.
warmup_count = (
override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
)

agg_turns: Dict[Tuple[str, int, str], int] = {}
for _ in range(self.request_warm_up_for_reshard):
turns = self.request_router.find_turning_points(nodes_list, num_layers)
for _ in range(warmup_count):
turns = DynamicProgrammingRouting.find_turning_points(nodes_list, num_layers)
for t in turns:
agg_turns[t] = agg_turns.get(t, 0) + 1

# Apply truncation for consistently observed turning points
# Note: Must use layer_allocator.allocate/deallocate to properly update
# internal state (node_allocation dict and layer_to_load)
for node_id, layer_idx, kind in agg_turns:
node = next((n for n in self.nodes if n.node_id == node_id), None)
if node is None or node.start_layer is None or node.end_layer is None:
continue
start, end = node.start_layer, node.end_layer
if kind == "tail":
if layer_idx < end:
node.set_layer_allocation(start, layer_idx)
self.layer_allocator.reallocate(node, start, layer_idx)
elif kind == "head":
if layer_idx > start:
node.set_layer_allocation(layer_idx, end)
self.layer_allocator.reallocate(node, layer_idx, end)
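
# Usage sketch (illustrative, not part of this file): a minimal helper that
# mirrors the tail/head truncation rules applied above, using the same
# half-open layer-range convention [start, end).
def truncated_range(start: int, end: int, layer_idx: int, kind: str) -> "tuple[int, int]":
    if kind == "tail" and layer_idx < end:
        return start, layer_idx            # drop [layer_idx, end)
    if kind == "head" and layer_idx > start:
        return layer_idx, end              # drop [start, layer_idx)
    return start, end                      # turning point outside the range: no change

assert truncated_range(4, 20, 16, "tail") == (4, 16)
assert truncated_range(4, 20, 8, "head") == (8, 20)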

def update_node_info(
self,
@@ -316,19 +360,20 @@ def leave(self, node_id: str) -> None:
f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
)
else:
# All nodes are automatic, proceed with rebalance
self._bootstrapped = False
self._bootstrapped_event.clear()
for n in self.nodes:
if n.start_layer is not None and n.end_layer is not None:
self.layer_allocator.deallocate(n)
success = self.layer_allocator.global_allocation()
if not success:
logger.warning("Global rebalance failed to produce a full pipeline")
# All nodes are automatic, try adjustment first, then rebalance if needed
if not self.layer_allocator.has_full_pipeline():
logger.debug(
"No full pipeline after node leave, attempting warmup and truncate"
)
self._run_warmup_and_truncate(override_warmup_count=1)
if not self.layer_allocator.has_full_pipeline():
self.bootstrap(clear_existing=True, skip_warmup=True)
else:
logger.debug(
"Pipeline recovered through warmup and truncate, skipping global rebalance"
)
else:
logger.debug("Global rebalance completed successfully")
self._bootstrapped = True
self._bootstrapped_event.set()
self.bootstrap(clear_existing=True, skip_warmup=True)

with self._node_count_cv:
self._node_count_cv.notify_all()
106 changes: 106 additions & 0 deletions tests/scheduler_tests/test_scheduler.py
@@ -133,3 +133,109 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers():
assert (
n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
), "After re-join, single node should be assigned a full layer range"


def test_scheduler_three_nodes_sequential_join_leave_rejoin():
"""Test scheduler with 28-layer model, 3 nodes each capable of 22 layers.

Scenario:
- 28-layer model
- n1, n2, n3 all can host 22 layers
- min_nodes_bootstrapping=2
- n1, n2, n3 join sequentially
- n1 leaves and rejoins
- n2 leaves and rejoins
- n3 leaves and rejoins
"""
model = build_model_info(28)

# Create nodes that can each host 22 layers
# Calculation: 100GB can host 16 layers, so 22 layers need ~137.5GB
# Using 138GB to ensure capacity for 22 layers with a small margin
n1 = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
n2 = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
n3 = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)

# Verify nodes can host 22 layers
assert n1.get_decoder_layer_capacity() >= 22, "n1 should be able to host 22 layers"
assert n2.get_decoder_layer_capacity() >= 22, "n2 should be able to host 22 layers"
assert n3.get_decoder_layer_capacity() >= 22, "n3 should be able to host 22 layers"

# Initialize scheduler with min_nodes_bootstrapping=2, no nodes initially
sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)

# Step 1: n1 joins (not enough nodes yet)
sched.enqueue_join(n1)
sched._process_joins() # type: ignore[attr-defined]
assert len(sched.nodes) == 1
assert not sched.layer_allocator.has_full_pipeline()

# Step 2: n2 joins (now we have 2 nodes, should bootstrap)
sched.enqueue_join(n2)
sched._process_joins() # type: ignore[attr-defined]
set_rtt_from_coords(sched.nodes)
ok = sched.bootstrap()
assert ok, "Bootstrap should succeed with 2 nodes"
assert sched.layer_allocator.has_full_pipeline()

# Step 3: n3 joins (dynamic join after bootstrap)
sched.enqueue_join(n3)
sched._process_joins() # type: ignore[attr-defined]
set_rtt_from_coords(sched.nodes)
assert n3.start_layer is not None and n3.end_layer is not None
assert len(sched.nodes) == 3

# Step 4: n1 leaves and rejoins
n1_id = n1.node_id
sched.leave(n1_id)
assert n1 not in sched.nodes
assert len(sched.nodes) == 2
assert sched.layer_allocator.has_full_pipeline()

# Rejoin n1
n1_rejoin = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
sched.enqueue_join(n1_rejoin)
sched._process_joins() # type: ignore[attr-defined]
set_rtt_from_coords(sched.nodes)
assert n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
assert len(sched.nodes) == 3
assert sched.layer_allocator.has_full_pipeline()

# Step 5: n2 leaves and rejoins
n2_id = n2.node_id
sched.leave(n2_id)
assert n2 not in sched.nodes
assert len(sched.nodes) == 2
assert sched.layer_allocator.has_full_pipeline()

# Rejoin n2
n2_rejoin = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
sched.enqueue_join(n2_rejoin)
sched._process_joins() # type: ignore[attr-defined]
set_rtt_from_coords(sched.nodes)
assert n2_rejoin.start_layer is not None and n2_rejoin.end_layer is not None
assert len(sched.nodes) == 3
assert sched.layer_allocator.has_full_pipeline()

# Step 6: n3 leaves and rejoins
n3_id = n3.node_id
sched.leave(n3_id)
assert n3 not in sched.nodes
assert len(sched.nodes) == 2
assert sched.layer_allocator.has_full_pipeline()

# Rejoin n3
n3_rejoin = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
sched.enqueue_join(n3_rejoin)
sched._process_joins() # type: ignore[attr-defined]
set_rtt_from_coords(sched.nodes)
assert n3_rejoin.start_layer is not None and n3_rejoin.end_layer is not None
assert len(sched.nodes) == 3
assert sched.layer_allocator.has_full_pipeline()

# Final verification: all nodes should have layer assignments
allocations = sched.list_node_allocations()
assert len(allocations) == 3, "All 3 nodes should have layer assignments"
# Verify full pipeline coverage
total_covered = sum(e - s for _, s, e in allocations)
assert total_covered >= model.num_layers, "All layers should be covered"