diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index 32ae69906..94a4affe8 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -80,6 +80,10 @@ class Generator(ForgeActor): sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine. use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled. + wait_for_pending_on_update (bool): If True (default), weight updates will block new requests and wait + for pending requests to complete. If False, enables in-flight weight updates. + reset_cache_on_update (bool): If True (default), resets the KV prefix cache after weight updates. + Set to False to preserve cache during updates. Example: >>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service( @@ -97,6 +101,8 @@ class Generator(ForgeActor): use_dcp_for_weight_sync: bool | None = None prefetch_weights_to_shm: bool = True n_fetcher_procs: int = 8 + wait_for_pending_on_update: bool = True + reset_cache_on_update: bool = True def __post_init__(self): super().__init__() @@ -426,19 +432,40 @@ async def run(self) -> None: self.request_lock.notify_all() @endpoint - async def update_weights(self, version: int) -> None: + async def update_weights( + self, + version: int, + *, + wait_for_pending: bool | None = None, + reset_cache: bool | None = None, + ) -> None: """Update weights on base model from a generator version to be found in a torchstore volume. Args: generator_version (int): Generator version from which to update. This will correspond to a key in a torchstore volume. + wait_for_pending (bool | None): If True, blocks new requests and waits for pending requests to + complete before updating weights. If False, updates weights immediately without waiting, + allowing in-flight updates (requests may see new weights mid-generation). + If None (default), uses the value from wait_for_pending_on_update config. + reset_cache (bool | None): If True, resets the KV prefix cache after updating weights. + Set to False to preserve cache when doing in-flight updates. + If None (default), uses the value from reset_cache_on_update config. Example: >>> trainer.train_step(...) >>> version += 1 >>> await trainer.push_weights() + >>> # Uses config defaults >>> generator.update_weights(version) + >>> # Override config for this call + >>> generator.update_weights(version, wait_for_pending=False, reset_cache=False) """ + # Use config defaults if not explicitly provided + if wait_for_pending is None: + wait_for_pending = self.wait_for_pending_on_update + if reset_cache is None: + reset_cache = self.reset_cache_on_update # TODO: enable shared memory prefetch for DCP-based weight sync if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync: logger.info(f"[Generator] Fetching weights for v{version} to shared memory") @@ -448,27 +475,30 @@ async def update_weights(self, version: int) -> None: # Serialize updates (only one update at a time) async with self.update_lock: # Grab the lock to stop accepting requests and wait on pending requests - async with self.request_lock: - self.accepting_requests = False - curr_requests = [fut for _, fut in self.requests.values()] - if curr_requests: - # Record pending requests metrics - record_metric( - "generator_perf/update_weights/avg_pending_requests", - len(curr_requests), - Reduce.MEAN, - ) - record_metric( - "generator_perf/update_weights/max_pending_requests", - len(curr_requests), - Reduce.MAX, - ) - logger.debug(f"Waiting for {len(curr_requests)} pending requests") - - # Wait until all pending requests have been processed - # TODO: If generating long sequences, this might be long and will block - # generator weight updates - await self.request_lock.wait_for(lambda: len(self.requests) == 0) + if wait_for_pending: + async with self.request_lock: + self.accepting_requests = False + curr_requests = [fut for _, fut in self.requests.values()] + if curr_requests: + # Record pending requests metrics + record_metric( + "generator_perf/update_weights/avg_pending_requests", + len(curr_requests), + Reduce.MEAN, + ) + record_metric( + "generator_perf/update_weights/max_pending_requests", + len(curr_requests), + Reduce.MAX, + ) + logger.debug( + f"Waiting for {len(curr_requests)} pending requests" + ) + + # Wait until all pending requests have been processed + # TODO: If generating long sequences, this might be long and will block + # generator weight updates + await self.request_lock.wait_for(lambda: len(self.requests) == 0) # Record weight update metrics record_metric( @@ -492,12 +522,14 @@ async def update_weights(self, version: int) -> None: self.generator_version = version # After updating the weights, we need to reset the KV cache - self.scheduler.reset_prefix_cache() + if reset_cache: + self.scheduler.reset_prefix_cache() # Resume accepting requests and wake up any waiting generate() calls - async with self.request_lock: - self.accepting_requests = True - self.request_lock.notify_all() + if wait_for_pending: + async with self.request_lock: + self.accepting_requests = True + self.request_lock.notify_all() logger.info(f"Weight update completed (now v{self.generator_version})")