From c9efa008f16172d931fd3b0df3be55c99acc3f60 Mon Sep 17 00:00:00 2001 From: jason Date: Fri, 7 Nov 2025 23:29:07 +0800 Subject: [PATCH 1/4] warmup before global rebalance --- src/scheduling/scheduler.py | 76 ++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 17 deletions(-) diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index 8d9ab23..3a7b0d2 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -152,25 +152,48 @@ def list_node_allocations(self) -> List[Tuple[str, int, int]]: return self.layer_allocator.list_node_allocations() # Warm-up and re-shard - def _run_warmup_and_truncate(self) -> None: + def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None: """Run a brief warm-up to detect truncation points and shrink shards. Uses layer-level DP turning points (node_id, layer_idx, kind): - kind == "tail": drop [layer_idx, end) on that node - kind == "head": drop [start, layer_idx) on that node + + Note: Always uses DynamicProgrammingRouting for finding turning points, + regardless of the current request_router type, since turning points + detection requires layer-level DP analysis. + + Args: + override_warmup_count: If > 0, use this value instead of request_warm_up_for_reshard. + Default is 0, which means use request_warm_up_for_reshard. """ nodes_list = list(self.nodes) if not nodes_list: return num_layers = self.model_info.num_layers + # The number of warm-up requests can be used to repeat detection, but a # single pass is sufficient with our DP model; we repeat to smooth noise. + warmup_count = ( + override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard + ) + + # Always use DP router for finding turning points, regardless of current router type + # This is because turning points detection requires layer-level DP analysis + dp_router = DynamicProgrammingRouting() + agg_turns: Dict[Tuple[str, int, str], int] = {} - for _ in range(self.request_warm_up_for_reshard): - turns = self.request_router.find_turning_points(nodes_list, num_layers) + for _ in range(warmup_count): + turns = dp_router.find_turning_points(nodes_list, num_layers) for t in turns: agg_turns[t] = agg_turns.get(t, 0) + 1 + + if not agg_turns: + return + # Apply truncation for consistently observed turning points + # Note: Must use layer_allocator.allocate/deallocate to properly update + # internal state (node_allocation dict and layer_to_load) for node_id, layer_idx, kind in agg_turns: node = next((n for n in self.nodes if n.node_id == node_id), None) if node is None or node.start_layer is None or node.end_layer is None: @@ -178,10 +201,12 @@ def _run_warmup_and_truncate(self) -> None: start, end = node.start_layer, node.end_layer if kind == "tail": if layer_idx < end: - node.set_layer_allocation(start, layer_idx) + self.layer_allocator.deallocate(node) + self.layer_allocator.allocate(node, start, layer_idx) elif kind == "head": if layer_idx > start: - node.set_layer_allocation(layer_idx, end) + self.layer_allocator.deallocate(node) + self.layer_allocator.allocate(node, layer_idx, end) def update_node_info( self, @@ -291,6 +316,22 @@ def join(self, node: Node, bootstrap: bool = False) -> None: with self._node_count_cv: self._node_count_cv.notify_all() + def _perform_global_rebalance(self) -> None: + """Perform global rebalancing: deallocate all nodes and reallocate.""" + logger.debug("Performing global rebalance") + self._bootstrapped = False + self._bootstrapped_event.clear() + for n in self.nodes: + if 
n.start_layer is not None and n.end_layer is not None: + self.layer_allocator.deallocate(n) + success = self.layer_allocator.global_allocation() + if not success: + logger.warning("Global rebalance failed to produce a full pipeline") + else: + logger.debug("Global rebalance completed successfully") + self._bootstrapped = True + self._bootstrapped_event.set() + def leave(self, node_id: str) -> None: """Remove a node from allocation and refresh plan and materialized nodes.""" if node_id not in self.layer_allocator.node_id_to_node: @@ -316,19 +357,20 @@ def leave(self, node_id: str) -> None: f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance" ) else: - # All nodes are automatic, proceed with rebalance - self._bootstrapped = False - self._bootstrapped_event.clear() - for n in self.nodes: - if n.start_layer is not None and n.end_layer is not None: - self.layer_allocator.deallocate(n) - success = self.layer_allocator.global_allocation() - if not success: - logger.warning("Global rebalance failed to produce a full pipeline") + # All nodes are automatic, try adjustment first, then rebalance if needed + if not self.layer_allocator.has_full_pipeline(): + logger.debug( + "No full pipeline after node leave, attempting warmup and truncate" + ) + self._run_warmup_and_truncate(override_warmup_count=1) + if not self.layer_allocator.has_full_pipeline(): + self._perform_global_rebalance() + else: + logger.debug( + "Pipeline recovered through warmup and truncate, skipping global rebalance" + ) else: - logger.debug("Global rebalance completed successfully") - self._bootstrapped = True - self._bootstrapped_event.set() + self._perform_global_rebalance() with self._node_count_cv: self._node_count_cv.notify_all() From 97370c814f7878dd7c65d65b9f912d77849ae230 Mon Sep 17 00:00:00 2001 From: jason Date: Tue, 11 Nov 2025 22:43:58 +0800 Subject: [PATCH 2/4] change after chris review --- src/scheduling/layer_allocation.py | 5 ++ src/scheduling/request_routing.py | 8 ++- src/scheduling/scheduler.py | 82 +++++++++++++++++------------- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/src/scheduling/layer_allocation.py b/src/scheduling/layer_allocation.py index cee8799..d7dd1a4 100644 --- a/src/scheduling/layer_allocation.py +++ b/src/scheduling/layer_allocation.py @@ -203,6 +203,11 @@ def deallocate(self, node: Node) -> None: node.is_active = False self._update_layer_loads_heap() + def reallocate(self, node: Node, start_layer: int, end_layer: int) -> None: + """Reallocate a node to a specific layer range.""" + self.deallocate(node) + self.allocate(node, start_layer, end_layer) + def declare(self, node: Node) -> None: """Declare a node to the allocator.""" if node.node_id not in self.node_id_to_node: diff --git a/src/scheduling/request_routing.py b/src/scheduling/request_routing.py index 9bd91e1..e4d5e09 100644 --- a/src/scheduling/request_routing.py +++ b/src/scheduling/request_routing.py @@ -52,11 +52,17 @@ class DynamicProgrammingRouting(RequestRoutingStrategy): minimum-latency node sequence and total latency. """ - def find_turning_points(self, nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]: + @staticmethod + def find_turning_points(nodes: List[Node], num_layers: int) -> List[Tuple[str, int, str]]: """Find shard truncation points via layer-level DP. DP state is (layer l, node i that hosts l). Node cost uses the node's per-layer latency proxy; edge cost uses RTT between nodes. 
+ + This is a static method that can be called directly without creating an instance: + DynamicProgrammingRouting.find_turning_points(nodes, num_layers) + + It can also be called via an instance, which will work due to Python's method resolution. """ if num_layers <= 0 or not nodes: return [] diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index 3a7b0d2..277792b 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -116,25 +116,51 @@ def __init__( pass # Orchestration helpers - def bootstrap(self) -> bool: - """Bootstrapping: first-time layer allocation and optional warm-up. + def bootstrap(self, *, clear_existing: bool = False, skip_warmup: bool = False) -> bool: + """Bootstrapping: + This method can be used for both initial bootstrapping and global rebalancing. + When clear_existing=True, it first deallocates all existing allocations before + performing global allocation (rebalancing behavior). When clear_existing=False, + it performs allocation on top of existing state (initial bootstrapping behavior). - Returns True if a full pipeline was established; False otherwise. + Args: + clear_existing: If True, deallocate all existing allocations before reallocating. + This is used for global rebalancing. Default is False. + skip_warmup: If True, skip the warm-up and truncate step. Default is False. + + Returns: + True if a full pipeline was established; False otherwise. """ - if len(self.nodes) < self.min_nodes_bootstrapping: + # Check node count only for initial bootstrapping (not rebalancing) + if not clear_existing and len(self.nodes) < self.min_nodes_bootstrapping: logger.debug( f"Bootstrapping deferred: have {len(self.nodes)} nodes; need >= {self.min_nodes_bootstrapping}" ) return False - logger.debug("Bootstrapping layer allocator") + + # Clear existing allocations if this is a rebalance + if clear_existing: + logger.debug("Performing global rebalance (clearing existing allocations)") + self._bootstrapped = False + self._bootstrapped_event.clear() + for n in self.nodes: + if n.start_layer is not None and n.end_layer is not None: + self.layer_allocator.deallocate(n) + else: + logger.debug("Bootstrapping layer allocator") + + # Perform global allocation success = self.layer_allocator.global_allocation() if not success: - logger.warning("Bootstrapping failed to produce a full pipeline") + logger.warning("Global allocation failed to produce a full pipeline") return False + assignments = self.list_node_allocations() logger.debug(f"Layer allocator assignments: {assignments}") + # Optional warm-up to find turning points and truncate node ranges - if self.request_warm_up_for_reshard > 0: + # Skip warmup for rebalancing scenarios (can be overridden with skip_warmup=False) + if not skip_warmup and self.request_warm_up_for_reshard > 0: self._run_warmup_and_truncate() assignments = self.list_node_allocations() logger.debug(f"Layer allocator assignments after turn-point warm-up: {assignments}") @@ -142,9 +168,11 @@ def bootstrap(self) -> bool: if not self.layer_allocator.has_full_pipeline(): logger.warning("Bootstrapping failed to produce a full pipeline") return False + self._bootstrapped = True self._bootstrapped_event.set() - logger.debug("Bootstrapping completed successfully; full pipeline established") + action = "rebalance" if clear_existing else "bootstrapping" + logger.debug(f"{action.capitalize()} completed successfully; full pipeline established") return True def list_node_allocations(self) -> List[Tuple[str, int, int]]: @@ -178,19 +206,12 @@ 
def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None: override_warmup_count if override_warmup_count > 0 else self.request_warm_up_for_reshard ) - # Always use DP router for finding turning points, regardless of current router type - # This is because turning points detection requires layer-level DP analysis - dp_router = DynamicProgrammingRouting() - agg_turns: Dict[Tuple[str, int, str], int] = {} for _ in range(warmup_count): - turns = dp_router.find_turning_points(nodes_list, num_layers) + turns = DynamicProgrammingRouting.find_turning_points(nodes_list, num_layers) for t in turns: agg_turns[t] = agg_turns.get(t, 0) + 1 - if not agg_turns: - return - # Apply truncation for consistently observed turning points # Note: Must use layer_allocator.allocate/deallocate to properly update # internal state (node_allocation dict and layer_to_load) @@ -201,12 +222,10 @@ def _run_warmup_and_truncate(self, override_warmup_count: int = 0) -> None: start, end = node.start_layer, node.end_layer if kind == "tail": if layer_idx < end: - self.layer_allocator.deallocate(node) - self.layer_allocator.allocate(node, start, layer_idx) + self.layer_allocator.reallocate(node, start, layer_idx) elif kind == "head": if layer_idx > start: - self.layer_allocator.deallocate(node) - self.layer_allocator.allocate(node, layer_idx, end) + self.layer_allocator.reallocate(node, layer_idx, end) def update_node_info( self, @@ -317,20 +336,11 @@ def join(self, node: Node, bootstrap: bool = False) -> None: self._node_count_cv.notify_all() def _perform_global_rebalance(self) -> None: - """Perform global rebalancing: deallocate all nodes and reallocate.""" - logger.debug("Performing global rebalance") - self._bootstrapped = False - self._bootstrapped_event.clear() - for n in self.nodes: - if n.start_layer is not None and n.end_layer is not None: - self.layer_allocator.deallocate(n) - success = self.layer_allocator.global_allocation() - if not success: - logger.warning("Global rebalance failed to produce a full pipeline") - else: - logger.debug("Global rebalance completed successfully") - self._bootstrapped = True - self._bootstrapped_event.set() + """Perform global rebalancing: deallocate all nodes and reallocate. + + This is a convenience wrapper around bootstrap(clear_existing=True, skip_warmup=True). 
+ """ + self.bootstrap(clear_existing=True, skip_warmup=True) def leave(self, node_id: str) -> None: """Remove a node from allocation and refresh plan and materialized nodes.""" @@ -364,13 +374,13 @@ def leave(self, node_id: str) -> None: ) self._run_warmup_and_truncate(override_warmup_count=1) if not self.layer_allocator.has_full_pipeline(): - self._perform_global_rebalance() + self.bootstrap(clear_existing=True, skip_warmup=True) else: logger.debug( "Pipeline recovered through warmup and truncate, skipping global rebalance" ) else: - self._perform_global_rebalance() + self.bootstrap(clear_existing=True, skip_warmup=True) with self._node_count_cv: self._node_count_cv.notify_all() From 2a2c8a2b4bd08854ab5fd6d1e197cd5c5b15a62d Mon Sep 17 00:00:00 2001 From: jason Date: Tue, 11 Nov 2025 22:50:05 +0800 Subject: [PATCH 3/4] remove _perform_global_rebalance --- src/scheduling/scheduler.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/scheduling/scheduler.py b/src/scheduling/scheduler.py index 277792b..d644471 100644 --- a/src/scheduling/scheduler.py +++ b/src/scheduling/scheduler.py @@ -335,13 +335,6 @@ def join(self, node: Node, bootstrap: bool = False) -> None: with self._node_count_cv: self._node_count_cv.notify_all() - def _perform_global_rebalance(self) -> None: - """Perform global rebalancing: deallocate all nodes and reallocate. - - This is a convenience wrapper around bootstrap(clear_existing=True, skip_warmup=True). - """ - self.bootstrap(clear_existing=True, skip_warmup=True) - def leave(self, node_id: str) -> None: """Remove a node from allocation and refresh plan and materialized nodes.""" if node_id not in self.layer_allocator.node_id_to_node: From 93a99818b7ab3aa7af840a655bcc3eb915e88584 Mon Sep 17 00:00:00 2001 From: jason Date: Wed, 12 Nov 2025 21:33:33 +0800 Subject: [PATCH 4/4] add unit test for node rebalance --- tests/scheduler_tests/test_scheduler.py | 106 ++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/tests/scheduler_tests/test_scheduler.py b/tests/scheduler_tests/test_scheduler.py index 0695a67..7213bd8 100644 --- a/tests/scheduler_tests/test_scheduler.py +++ b/tests/scheduler_tests/test_scheduler.py @@ -133,3 +133,109 @@ def test_scheduler_single_node_leave_then_rejoin_reassigns_layers(): assert ( n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None ), "After re-join, single node should be assigned a full layer range" + + +def test_scheduler_three_nodes_sequential_join_leave_rejoin(): + """Test scheduler with 28-layer model, 3 nodes each capable of 22 layers. 
+
+    Scenario:
+    - 28-layer model
+    - n1, n2, n3 all can host 22 layers
+    - min_nodes_bootstrapping=2
+    - n1, n2, n3 join sequentially
+    - n1 leaves and rejoins
+    - n2 leaves and rejoins
+    - n3 leaves and rejoins
+    """
+    model = build_model_info(28)
+
+    # Create nodes that can each host 22 layers
+    # Calculation: 100GB can host 16 layers, so 22 layers need ~137.5GB
+    # Using 138GB, just above that requirement, to leave a small margin
+    n1 = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    n2 = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    n3 = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+
+    # Verify nodes can host 22 layers
+    assert n1.get_decoder_layer_capacity() >= 22, "n1 should be able to host 22 layers"
+    assert n2.get_decoder_layer_capacity() >= 22, "n2 should be able to host 22 layers"
+    assert n3.get_decoder_layer_capacity() >= 22, "n3 should be able to host 22 layers"
+
+    # Initialize scheduler with min_nodes_bootstrapping=2, no nodes initially
+    sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)
+
+    # Step 1: n1 joins (not enough nodes yet)
+    sched.enqueue_join(n1)
+    sched._process_joins()  # type: ignore[attr-defined]
+    assert len(sched.nodes) == 1
+    assert not sched.layer_allocator.has_full_pipeline()
+
+    # Step 2: n2 joins (now we have 2 nodes, should bootstrap)
+    sched.enqueue_join(n2)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    ok = sched.bootstrap()
+    assert ok, "Bootstrap should succeed with 2 nodes"
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 3: n3 joins (dynamic join after bootstrap)
+    sched.enqueue_join(n3)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3.start_layer is not None and n3.end_layer is not None
+    assert len(sched.nodes) == 3
+
+    # Step 4: n1 leaves and rejoins
+    n1_id = n1.node_id
+    sched.leave(n1_id)
+    assert n1 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n1
+    n1_rejoin = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
+    sched.enqueue_join(n1_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n1_rejoin.start_layer is not None and n1_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 5: n2 leaves and rejoins
+    n2_id = n2.node_id
+    sched.leave(n2_id)
+    assert n2 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n2
+    n2_rejoin = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
+    sched.enqueue_join(n2_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n2_rejoin.start_layer is not None and n2_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Step 6: n3 leaves and rejoins
+    n3_id = n3.node_id
+    sched.leave(n3_id)
+    assert n3 not in sched.nodes
+    assert len(sched.nodes) == 2
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Rejoin n3
+    n3_rejoin = build_node("n3", model, tflops=312.0, mem_gb=138.0, x=2, y=0)
+    sched.enqueue_join(n3_rejoin)
+    sched._process_joins()  # type: ignore[attr-defined]
+    set_rtt_from_coords(sched.nodes)
+    assert n3_rejoin.start_layer is not None and n3_rejoin.end_layer is not None
+    assert len(sched.nodes) == 3
+    assert sched.layer_allocator.has_full_pipeline()
+
+    # Final verification: all nodes should have layer assignments
+    allocations = sched.list_node_allocations()
+    assert len(allocations) == 3, "All 3 nodes should have layer assignments"
+    # Verify full pipeline coverage
+    total_covered = sum(e - s for _, s, e in allocations)
+    assert total_covered >= model.num_layers, "All layers should be covered"
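
A minimal illustrative sketch, appended after the patch series rather than part of it, of driving the reworked keyword-only bootstrap() as an explicit rebalance. It reuses the helpers this test module already relies on (build_model_info, build_node, set_rtt_from_coords, Scheduler); the test name is hypothetical, and the expectation that a second global allocation over the same two nodes still yields a full pipeline is an assumption based on the earlier successful bootstrap, not something the patches guarantee.

def test_bootstrap_clear_existing_rebalances_in_place():  # hypothetical name, illustration only
    model = build_model_info(28)
    n1 = build_node("n1", model, tflops=312.0, mem_gb=138.0, x=0, y=0)
    n2 = build_node("n2", model, tflops=312.0, mem_gb=138.0, x=1, y=0)
    sched = Scheduler(model, [], strategy="dp", min_nodes_bootstrapping=2)
    for node in (n1, n2):
        sched.enqueue_join(node)
        sched._process_joins()  # type: ignore[attr-defined]
    set_rtt_from_coords(sched.nodes)

    # Initial bootstrap: default path, warm-up and truncate allowed.
    assert sched.bootstrap(), "Initial bootstrap should establish a full pipeline"

    # Explicit rebalance: clear existing ranges and skip the warm-up pass,
    # mirroring how Scheduler.leave() now invokes bootstrap().
    assert sched.bootstrap(clear_existing=True, skip_warmup=True)
    assert sched.layer_allocator.has_full_pipeline()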