Conversation

@allenwang28 (Contributor)

This PR introduces the Trainer Protocol in src/forge/api/trainer.py, establishing a unified training interface that all trainer implementations will conform to.

Motivation

Currently, Forge applications directly use Monarch actors (e.g., RLTrainer.options(...).as_actor(...)), which exposes implementation details like .route() and .fanout() to application code.

This creates tight coupling and makes it difficult to:

  • Switch between different trainer backends (TorchTitan, HuggingFace, etc.)
  • Write portable application code that doesn't depend on Monarch specifics

Protocol + Wrappers

Note that we're using Python's Protocol, and not ABC! In case you weren't aware, there is a big philosophical debate about ABC vs Protocol that Claude has introduced me to. I'm primarily choosing Protocol because it's lighter weight (and let me know if you disagree strongly).

Why Protocol and not ABC?

We want a public, lightweight interface that anyone can satisfy without importing our base class. Protocol gives us structural typing: if it quacks like a duck, it counts as a duck. An ABC would force nominal inheritance and encourage a class hierarchy we don’t actually need.

TL;DR:

  • Looser coupling: Call sites accept “anything with the right methods,” not “anything that inherits our base.”
  • Frictionless third-party impls: External teams can implement the interface without depending on our internals.
  • Small, composable capabilities: Easy to define narrow traits and mix them.
  • Optional runtime checks: If desired, @runtime_checkable enables isinstance(x, Trainer) as a light guard.
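
To make that last point concrete, here is a minimal sketch of a runtime-checkable protocol. The method name and signatures below are placeholders for illustration, not the actual contents of src/forge/api/trainer.py:

  from typing import Any, Protocol, runtime_checkable

  # Placeholder batch type for this sketch; the real API uses richer types
  # such as TextTrainBatch.
  Batch = dict[str, Any]

  @runtime_checkable
  class Trainer(Protocol):
      """Structural interface: any object with matching methods satisfies it."""

      async def train_step(self, batch: Batch) -> dict[str, Any]:
          ...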

What it looks like in practice:

With ABC:

  # Would force inheritance
  class TitanTrainer(Trainer):  # Must inherit from ABC
      def __init__(self, actor_handle):
          super().__init__()  # ABC initialization overhead
          self._actor = actor_handle

With Protocol:

  # No inheritance required
  class TitanTrainer:  # Just a plain class
      def __init__(self, actor_handle):
          self._actor = actor_handle  # That's it
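
On the caller side, and assuming the sketched Trainer protocol and Batch alias above (plus a TitanTrainer that actually implements train_step), call sites type against behavior rather than ancestry:

  async def run_training(trainer: Trainer, batches: list[Batch]) -> None:
      # Optional light guard enabled by @runtime_checkable; it only checks
      # that the expected methods exist, not their signatures.
      if not isinstance(trainer, Trainer):
          raise TypeError("object does not satisfy the Trainer protocol")
      for batch in batches:
          await trainer.train_step(batch)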

Why this matters:

  • Simple/thin wrappers: HFTrainer, TitanTrainer, etc. can be simple adapters over the Monarch actor.
  • Fungibility by default: Third parties can drop in their own trainer without subclassing anything.
  • Stability for callers: Callers type against the behavior (the protocol), so internal refactors don’t cascade.
  • Escape hatch: If we later need shared behavior, we can add an optional BaseTrainer(ABC) with helpers/metrics without changing the public Trainer protocol (see the sketch after this list).
  • Ultimately this all allows us to keep looser coupling between the protocol definition and implementation.
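
To illustrate that escape hatch, here is a sketch of an optional base class, again assuming the hypothetical Trainer protocol above; implementations remain free to ignore it entirely:

  from abc import ABC
  from typing import Any

  class BaseTrainer(ABC):
      """Optional shared helpers; not part of the public Trainer protocol."""

      def log_metrics(self, metrics: dict[str, Any]) -> None:
          print(metrics)

  class TitanTrainer(BaseTrainer):
      """Still satisfies the Trainer protocol structurally via its methods."""

      def __init__(self, actor_handle):
          self._actor = actor_handle

      async def train_step(self, batch: dict[str, Any]) -> dict[str, Any]:
          # Hypothetical forwarding; a real wrapper would hide the Monarch
          # routing calls behind this method.
          return await self._actor.train_step(batch)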

Other planned changes:

The Protocol is step 1 of a multi-PR refactor. Other planned changes:

  • Restructure actor/trainer.py into trainer/titan.py and rename RLTrainer to TitanTrainerActor. Add TitanTrainer wrapper class that hides Monarch adverbs
  • Implement the rest of the API for the Titan trainer (we only implement train_step right now)
  • App migration - maybe after the other API changes have landed

@meta-cla meta-cla bot added the CLA Signed label Nov 6, 2025
@joecummings (Member) left a comment

Overall very happy with this, but I need to see more details in the docstrings because it's not entirely clear what's being passed into each method.

@casteryh (Contributor) commented Nov 6, 2025

I like protocol more than ABC

@joecummings (Member) left a comment

The general idea is good, but there are a lot of small things that don't quite make sense.

In general, let's ground the interface solidly in our current use case, e.g., let's not include params or methods that aren't currently in use.

"""Protocol defining the standard interface for all Forge trainers."""

async def forward_backward(
self, batch: TextTrainBatch, loss_fn: LossFn | None = None

As shared offline, this definition as-is with loss_fn may not be flexible enough to support different losses. Some losses expect different outputs from the model forward (e.g., logits for cross entropy, last hidden state + final lm_head for cut cross entropy, etc.).

@allenwang28 (Contributor, Author)

Modified the LossFn to take a dict of inputs and outputs for now. I'm not really sure how we can make this more specific until we see the space of obscure use cases, but let me know if this is still too restrictive.
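
For reference, roughly what a dict-based LossFn could look like; the alias name comes from the PR, but the exact signature and dict keys below are assumptions for illustration:

  from typing import Any, Callable

  import torch
  import torch.nn.functional as F

  # Assumed shape of the relaxed alias: the loss sees the raw forward outputs
  # and the batch inputs as dicts, and returns a scalar loss tensor.
  LossFn = Callable[[dict[str, Any], dict[str, Any]], torch.Tensor]

  def cross_entropy_loss(outputs: dict[str, Any], inputs: dict[str, Any]) -> torch.Tensor:
      logits = outputs["logits"]  # assumed key
      labels = inputs["labels"]   # assumed key
      return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))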
