Add Trainer Protocol #533
base: main
Conversation
joecummings left a comment:

Overall very happy with this, but I need to see more details in the docstrings because it's not entirely clear what's being passed into each method.
I like Protocol more than ABC.
joecummings left a comment:

The general idea is good, but there are a lot of small things that don't quite make sense. In general, let's ground the interface solidly in our current use case: e.g., let's not include params or methods that are not currently in use.
| """Protocol defining the standard interface for all Forge trainers.""" | ||
|
|
||
| async def forward_backward( | ||
| self, batch: TextTrainBatch, loss_fn: LossFn | None = None |
As shared offline, this definition as-is with loss_fn may not be flexible enough to support different losses. Some losses may expect different outputs from the model forward (i.e., logits for cross entropy, last hidden state + final lm_head for cut cross entropy, etc.).
Modified the LossFn to take a dict of inputs and outputs for now. I'm not really sure how we can make this more specific until we see the space of obscure use cases, but let me know if this is still too restrictive.
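For concreteness, a minimal sketch of what a dict-based loss signature could look like. The LossFn name comes from this PR, but the key names and the example loss below are illustrative assumptions, not the actual implementation:

```python
from typing import Any, Callable, Mapping

import torch
import torch.nn.functional as F

# Sketch: the loss receives the forward inputs and outputs as dicts, so losses
# that need logits, hidden states, etc. can each pick out what they need.
# The key names ("logits", "labels") are assumptions for illustration.
LossFn = Callable[[Mapping[str, Any], Mapping[str, Any]], torch.Tensor]

def cross_entropy_loss(
    inputs: Mapping[str, Any], outputs: Mapping[str, Any]
) -> torch.Tensor:
    logits = outputs["logits"]  # [batch, seq, vocab]
    labels = inputs["labels"]   # [batch, seq]
    return F.cross_entropy(logits.flatten(0, 1), labels.flatten())
```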
This PR introduces the TrainerProtocol in src/forge/api/trainer.py, establishing a unified training interface that all trainer implementations will conform to.

Motivation
Currently, Forge applications directly use Monarch actors (e.g., RLTrainer.options(...).as_actor(...)), which exposes implementation details like .route() and .fanout() to application code. This creates tight coupling and makes it difficult to:
Protocol + Wrappers
Note that we're using Python's Protocol, and not ABC! In case you weren't aware, there is a big philosophical debate about ABC vs Protocol that Claude has introduced me to. I'm primarily choosing Protocol because it's lighter weight (and let me know if you disagree strongly).
Why Protocol and not ABC?
We want a public, lightweight interface that anyone can satisfy without importing our base class. Protocol gives us structural typing: if it quacks, it flies. An ABC would force nominal inheritance and encourage a hierarchy we don’t actually need.
TL;DR: @runtime_checkable enables isinstance(x, Trainer) as a light guard.

What it looks like in practice:
With ABC:
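A minimal sketch, assuming a hypothetical BaseTrainer base class (only forward_backward follows this PR's naming; the rest is illustrative):

```python
from abc import ABC, abstractmethod

class BaseTrainer(ABC):
    """Nominal typing: implementations must import and subclass this."""

    @abstractmethod
    async def forward_backward(self, batch: dict) -> dict: ...

class MyTrainer(BaseTrainer):  # coupled to the base class by inheritance
    async def forward_backward(self, batch: dict) -> dict:
        return {"loss": 0.0}
```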
With Protocol:
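The structural equivalent, again as a sketch (Trainer and @runtime_checkable follow this PR; the method body is illustrative):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Trainer(Protocol):
    """Structural typing: any class with matching methods conforms."""

    async def forward_backward(self, batch: dict) -> dict: ...

class MyTrainer:  # no inheritance, no import of the interface needed
    async def forward_backward(self, batch: dict) -> dict:
        return {"loss": 0.0}

# @runtime_checkable lets isinstance act as a light guard; note that it only
# checks that the method exists, not its signature.
assert isinstance(MyTrainer(), Trainer)
```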
Why this matters:
- HFTrainer, TitanTrainer, etc. can be simple adapters over the Monarch actor (see the sketch below).
- A shared BaseTrainer(ABC) with helpers/metrics can still be added later, without changing the public Trainer protocol.
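A hypothetical sketch of the adapter idea, keeping actor routing behind the protocol. The actor handle and its .route() call are stand-ins based on the calls quoted above; the real Monarch API may differ:

```python
class TitanTrainer:
    """Adapter: satisfies the Trainer protocol while hiding Monarch details."""

    def __init__(self, actor) -> None:
        self._actor = actor  # Monarch actor handle (stand-in for illustration)

    async def forward_backward(self, batch: dict) -> dict:
        # Application code never touches .route()/.fanout(); the adapter does.
        return await self._actor.forward_backward.route(batch)
```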
Other planned changes
The Protocol is step 1 of a multi-PR refactor; the remaining changes will land in follow-up PRs.