diff --git a/src/scheduling/layer_allocation.py b/src/scheduling/layer_allocation.py
index cee8799c..d7dd1a4f 100644
--- a/src/scheduling/layer_allocation.py
+++ b/src/scheduling/layer_allocation.py
@@ -203,6 +203,11 @@ def deallocate(self, node: Node) -> None:
         node.is_active = False
         self._update_layer_loads_heap()
 
+    def reallocate(self, node: Node, start_layer: int, end_layer: int) -> None:
+        """Reallocate a node to a specific layer range."""
+        self.deallocate(node)
+        self.allocate(node, start_layer, end_layer)
+
     def declare(self, node: Node) -> None:
         """Declare a node to the allocator."""
         if node.node_id not in self.node_id_to_node:
diff --git a/src/scheduling/request_routing.py b/src/scheduling/request_routing.py
index 9bd91e16..e4d5e09b 100644
--- a/src/scheduling/request_routing.py
+++ b/src/scheduling/request_routing.py
@@ -52,11 +52,17 @@ class DynamicProgrammingRouting(RequestRoutingStrategy):
     minimum-latency node sequence and total latency.
     """
 
-    def find_turning_points(self, nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
+    @staticmethod
+    def find_turning_points(nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]:
         """Find shard truncation points via layer-level DP.
 
         DP state is (layer l, node i that hosts l). Node cost uses the node's
         per-layer latency proxy; edge cost uses RTT between nodes.
+
+        This is a static method that can be called directly without creating an instance:
+            DynamicProgrammingRouting.find_turning_points(nodes, num_layers)
+
+        It can also be called via an instance; both forms behave identically.
         """
         if num_layers <= 0 or not nodes:
             return []
diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py
index 8d9ab237..d6444718 100644
--- a/src/scheduling/scheduler.py
+++ b/src/scheduling/scheduler.py
@@ -116,25 +116,51 @@ def __init__(
         pass
 
     # Orchestration helpers
-    def bootstrap(self) -> bool:
-        """Bootstrapping: first-time layer allocation and optional warm-up.
+    def bootstrap(self, *, clear_existing: bool = False, skip_warmup: bool = False) -> bool:
+        """Bootstrap the layer allocation, or perform a global rebalance.
+
+        When clear_existing=True, the method first deallocates all existing allocations before
+        performing global allocation (rebalancing behavior). When clear_existing=False,
+        it performs allocation on top of existing state (initial bootstrapping behavior).
 
-        Returns True if a full pipeline was established; False otherwise.
+        Args:
+            clear_existing: If True, deallocate all existing allocations before reallocating.
+                This is used for global rebalancing. Default is False.
+            skip_warmup: If True, skip the warm-up and truncate step. Default is False.
+
+        Returns:
+            True if a full pipeline was established; False otherwise.
""" - if len(self.nodes) < self.min_nodes_bootstrapping: + # Check node count only for initial bootstrapping (not rebalancing) + if not clear_existing and len(self.nodes) < self.min_nodes_bootstrapping: logger.debug( f"Bootstrapping deferred: have {len(self.nodes)} nodes; need >= {self.min_nodes_bootstrapping}" ) return False - logger.debug("Bootstrapping layer allocator") + + # Clear existing allocations if this is a rebalance + if clear_existing: + logger.debug("Performing global rebalance (clearing existing allocations)") + self._bootstrapped = False + self._bootstrapped_event.clear() + for n in self.nodes: + if n.start_layer is not None and n.end_layer is not None: + self.layer_allocator.deallocate(n) + else: + logger.debug("Bootstrapping layer allocator") + + # Perform global allocation success = self.layer_allocator.global_allocation() if not success: - logger.warning("Bootstrapping failed to produce a full pipeline") + logger.warning("Global allocation failed to produce a full pipeline") return False + assignments = self.list_node_allocations() logger.debug(f"Layer allocator assignments: {assignments}") + # Optional warm-up to find turning points and truncate node ranges - if self.request_warm_up_for_reshard > 0: + # Skip warmup for rebalancing scenarios (can be overridden with skip_warmup=False) + if not skip_warmup and self.request_warm_up_for_reshard > 0: self._run_warmup_and_truncate() assignments = self.list_node_allocations() logger.debug(f"Layer allocator assignments after turn-point warm-up: {assignments}") @@ -142,9 +168,11 @@ def bootstrap(self) -> bool: if not self.layer_allocator.has_full_pipeline(): logger.warning("Bootstrapping failed to produce a full pipeline") return False + self._bootstrapped = True self._bootstrapped_event.set() - logger.debug("Bootstrapping completed successfully; full pipeline established") + action = "rebalance" if clear_existing else "bootstrapping" + logger.debug(f"{action.capitalize()} completed successfully; full pipeline established") return True def list_node_allocations(self) -> List[Tuple[str, int, int]]: @@ -152,25 +180,41 @@ def list_node_allocations(self) -> List[Tuple[str, int, int]]: return self.layer_allocator.list_node_allocations() # Warm-up and re-shard - def _run_warmup_and_truncate(self) -> None: + def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None: """Run a brief warm-up to detect truncation points and shrink shards. Uses layer-level DP turning points (node_id, layer_idx, kind): - kind == "tail": drop [layer_idx, end) on that node - kind == "head": drop [start, layer_idx) on that node + + Note: Always uses DynamicProgrammingRouting for finding turning points, + regardless of the current request_router type, since turning points + detection requires layer-level DP analysis. + + Args: + override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard. + Default is 0, which means use request_warm_up_for_reshard. """ nodes_list = list(self.nodes) if not nodes_list: return num_layers = self.model_info.num_layers + # The number of warm-up requests can be used to repeat detection, but a # single pass is sufficient with our DP model; we repeat to smooth noise. 
+        warmup_count = (
+            override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard
+        )
+
         agg_turns: Dict[Tuple[str, int, str], int] = {}
-        for _ in range(self.request_warm_up_for_reshard):
-            turns = self.request_router.find_turning_points(nodes_list, num_layers)
+        for _ in range(warmup_count):
+            turns = DynamicProgrammingRouting.find_turning_points(nodes_list, num_layers)
             for t in turns:
                 agg_turns[t] = agg_turns.get(t, 0) + 1
 
+        # Apply truncation for the observed turning points
+        # Note: Must use layer_allocator.allocate/deallocate to properly update
+        # internal state (node_allocation dict and layer_to_load)
         for node_id, layer_idx, kind in agg_turns:
             node = next((n for n in self.nodes if n.node_id == node_id), None)
             if node is None or node.start_layer is None or node.end_layer is None:
@@ -178,10 +222,10 @@ def _run_warmup_and_truncate(self) -> None:
             start, end = node.start_layer, node.end_layer
             if kind == "tail":
                 if layer_idx < end:
-                    node.set_layer_allocation(start, layer_idx)
+                    self.layer_allocator.reallocate(node, start, layer_idx)
             elif kind == "head":
                 if layer_idx > start:
-                    node.set_layer_allocation(layer_idx, end)
+                    self.layer_allocator.reallocate(node, layer_idx, end)
 
     def update_node_info(
         self,
@@ -316,19 +360,20 @@ def leave(self, node_id: str) -> None:
                     f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
                 )
             else:
-                # All nodes are automatic, proceed with rebalance
-                self._bootstrapped = False
-                self._bootstrapped_event.clear()
-                for n in self.nodes:
-                    if n.start_layer is not None and n.end_layer is not None:
-                        self.layer_allocator.deallocate(n)
-                success = self.layer_allocator.global_allocation()
-                if not success:
-                    logger.warning("Global rebalance failed to produce a full pipeline")
+                # All nodes are automatic, try adjustment first, then rebalance if needed
+                if not self.layer_allocator.has_full_pipeline():
+                    logger.debug(
+                        "No full pipeline after node leave, attempting warmup and truncate"
+                    )
+                    self._run_warmup_and_truncate(override_warmup_count=1)
+                    if not self.layer_allocator.has_full_pipeline():
+                        self.bootstrap(clear_existing=True, skip_warmup=True)
+                    else:
+                        logger.debug(
+                            "Pipeline recovered through warmup and truncate, skipping global rebalance"
+                        )
                 else:
-                    logger.debug("Global rebalance completed successfully")
-                    self._bootstrapped = True
-                    self._bootstrapped_event.set()
+                    self.bootstrap(clear_existing=True, skip_warmup=True)
 
         with self._node_count_cv:
             self._node_count_cv.notify_all()
diff --git a/tests/scheduler_tests/test_scheduler.py b/tests/scheduler_tests/test_scheduler.py
index 0695a672..7213bd85 100644
--- a/tests/scheduler_tests/test_scheduler.py
+++ b/tests/scheduler_tests/test_scheduler.py
@@ -133,3 +133,109 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers():
     assert (
         n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
     ), "After re-join, single node should be assigned a full layer range"
+
+
+def test_scheduler_three_nodes_sequential_join_leave_rejoin():
+    """Test scheduler with 28-layer model, 3 nodes each capable of 22 layers.
+
+    Scenario:
+    - 28-layer model
+    - n1, n2, n3 all can host 22 layers
+    - min_nodes_bootstrapping=2
+    - n1, n2, n3 join sequentially
+    - n1 leaves and rejoins
+    - n2 leaves and rejoins
+    - n3 leaves and rejoins
+    """
+    model = build_model_info(28)
+
+    # Create nodes that can each host 22 layers
+    # Calculation: 100GB can host 16 layers, so 22 layers need ~137.5GB
+    # Using 138GB covers that requirement with a small margin
+    n1 = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    n2 = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    n3 = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+
+    # Verify nodes can host 22 layers
+    assert n1.get_decoder_layer_capacity() >= 22, "n1 should be able to host 22 layers"
+    assert n2.get_decoder_layer_capacity() >= 22, "n2 should be able to host 22 layers"
+    assert n3.get_decoder_layer_capacity() >= 22, "n3 should be able to host 22 layers"
+
+    # Initialize scheduler with min_nodes_bootstrapping=2, no nodes initially
+    sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)
+
+    # Step 1: n1 joins (not enough nodes yet)
+    sched.enqueue_join(n1)
+    sched._process_joins()  # type: ignore[attr-defined]
+    assert len(sched.nodes) == 1
+    assert not sched.layer_allocator.has_full_pipeline()
+
+    # Step 2: n2 joins (now we have 2 nodes, should bootstrap)
+    sched.enqueue_join(n2)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    ok = sched.bootstrap()
+    assert ok, "Bootstrap should succeed with 2 nodes"
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 3: n3 joins (dynamic join after bootstrap)
+    sched.enqueue_join(n3)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3.start_layer is not None and n3.end_layer is not None
+    assert len(sched.nodes) == 3
+
+    # Step 4: n1 leaves and rejoins
+    n1_id = n1.node_id
+    sched.leave(n1_id)
+    assert n1 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n1
+    n1_rejoin = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    sched.enqueue_join(n1_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 5: n2 leaves and rejoins
+    n2_id = n2.node_id
+    sched.leave(n2_id)
+    assert n2 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n2
+    n2_rejoin = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    sched.enqueue_join(n2_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n2_rejoin.start_layer is not None and n2_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 6: n3 leaves and rejoins
+    n3_id = n3.node_id
+    sched.leave(n3_id)
+    assert n3 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n3
+    n3_rejoin = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+    sched.enqueue_join(n3_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3_rejoin.start_layer is not None and n3_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Final verification: all nodes should have layer assignments
+    allocations = sched.list_node_allocations()
+    assert len(allocations) == 3, "All 3 nodes should have layer assignments"
+    # Verify full pipeline coverage
+    total_covered = sum(e - s for _, s, e in allocations)
+    assert total_covered >= model.num_layers, "All layers should be covered"
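
A minimal usage sketch of the reworked entry points, reusing the helpers that already appear in the test suite (build_model_info, build_node, set_rtt_from_coords, Scheduler); the setup values are illustrative only and mirror the new test above.

    # Hedged sketch, not part of the patch: drive bootstrap vs. global rebalance explicitly.
    model = build_model_info(28)
    sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)
    for i in range(3):
        sched.enqueue_join(build_node(f"n{i + 1}", model, tflops=312.0, mem_gb=138.0, x=i, y=0))
        sched._process_joins()  # type: ignore[attr-defined]
    set_rtt_from_coords(sched.nodes)

    # Initial bootstrapping: the min_nodes_bootstrapping gate applies and the
    # warm-up/truncate pass runs when request_warm_up_for_reshard > 0.
    sched.bootstrap()

    # Global rebalance: clear existing allocations first and skip the warm-up.
    # This is the call the automatic-rebalance branch in leave() now makes.
    sched.bootstrap(clear_existing=True, skip_warmup=True)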