Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 58 additions & 26 deletions src/forge/actors/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class Generator(ForgeActor):
sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
use_dcp_for_weight_sync (bool): Whether to use DCP for NFS-based weight sync. Default depends on
whether or not RDMA is enabled in torchstore. If it is, then DCP is disabled. Otherwise, DCP is enabled.
wait_for_pending_on_update (bool): If True (default), weight updates will block new requests and wait
for pending requests to complete. If False, enables in-flight weight updates.
reset_cache_on_update (bool): If True (default), resets the KV prefix cache after weight updates.
Set to False to preserve cache during updates.

Example:
>>> generator = await Generator.options(procs=1, num_replicas=1, with_gpus=True).as_service(
Expand All @@ -97,6 +101,8 @@ class Generator(ForgeActor):
use_dcp_for_weight_sync: bool | None = None
prefetch_weights_to_shm: bool = True
n_fetcher_procs: int = 8
wait_for_pending_on_update: bool = True
reset_cache_on_update: bool = True

def __post_init__(self):
super().__init__()
Expand Down Expand Up @@ -426,19 +432,40 @@ async def run(self) -> None:
self.request_lock.notify_all()

@endpoint
async def update_weights(self, version: int) -> None:
async def update_weights(
self,
version: int,
*,
wait_for_pending: bool | None = None,
reset_cache: bool | None = None,
) -> None:
"""Update weights on base model from a generator version to be found in a torchstore volume.

Args:
generator_version (int): Generator version from which to update. This will correspond to a key in a
torchstore volume.
wait_for_pending (bool | None): If True, blocks new requests and waits for pending requests to
complete before updating weights. If False, updates weights immediately without waiting,
allowing in-flight updates (requests may see new weights mid-generation).
If None (default), uses the value from wait_for_pending_on_update config.
reset_cache (bool | None): If True, resets the KV prefix cache after updating weights.
Set to False to preserve cache when doing in-flight updates.
If None (default), uses the value from reset_cache_on_update config.

Example:
>>> trainer.train_step(...)
>>> version += 1
>>> await trainer.push_weights()
>>> # Uses config defaults
>>> generator.update_weights(version)
>>> # Override config for this call
>>> generator.update_weights(version, wait_for_pending=False, reset_cache=False)
"""
# Use config defaults if not explicitly provided
if wait_for_pending is None:
wait_for_pending = self.wait_for_pending_on_update
if reset_cache is None:
reset_cache = self.reset_cache_on_update
# TODO: enable shared memory prefetch for DCP-based weight sync
if self.prefetch_weights_to_shm and not self.use_dcp_for_weight_sync:
logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
Expand All @@ -448,27 +475,30 @@ async def update_weights(self, version: int) -> None:
# Serialize updates (only one update at a time)
async with self.update_lock:
# Grab the lock to stop accepting requests and wait on pending requests
async with self.request_lock:
self.accepting_requests = False
curr_requests = [fut for _, fut in self.requests.values()]
if curr_requests:
# Record pending requests metrics
record_metric(
"generator_perf/update_weights/avg_pending_requests",
len(curr_requests),
Reduce.MEAN,
)
record_metric(
"generator_perf/update_weights/max_pending_requests",
len(curr_requests),
Reduce.MAX,
)
logger.debug(f"Waiting for {len(curr_requests)} pending requests")

# Wait until all pending requests have been processed
# TODO: If generating long sequences, this might be long and will block
# generator weight updates
await self.request_lock.wait_for(lambda: len(self.requests) == 0)
if wait_for_pending:
async with self.request_lock:
self.accepting_requests = False
curr_requests = [fut for _, fut in self.requests.values()]
if curr_requests:
# Record pending requests metrics
record_metric(
"generator_perf/update_weights/avg_pending_requests",
len(curr_requests),
Reduce.MEAN,
)
record_metric(
"generator_perf/update_weights/max_pending_requests",
len(curr_requests),
Reduce.MAX,
)
logger.debug(
f"Waiting for {len(curr_requests)} pending requests"
)

# Wait until all pending requests have been processed
# TODO: If generating long sequences, this might be long and will block
# generator weight updates
await self.request_lock.wait_for(lambda: len(self.requests) == 0)

# Record weight update metrics
record_metric(
Expand All @@ -492,12 +522,14 @@ async def update_weights(self, version: int) -> None:
self.generator_version = version

# After updating the weights, we need to reset the KV cache
self.scheduler.reset_prefix_cache()
if reset_cache:
self.scheduler.reset_prefix_cache()

# Resume accepting requests and wake up any waiting generate() calls
async with self.request_lock:
self.accepting_requests = True
self.request_lock.notify_all()
if wait_for_pending:
async with self.request_lock:
self.accepting_requests = True
self.request_lock.notify_all()

logger.info(f"Weight update completed (now v{self.generator_version})")

Expand Down
Loading