77 changes: 77 additions & 0 deletions apps/on_policy_distillation/README.md
@@ -0,0 +1,77 @@
# On-Policy Distillation for Math Reasoning

This app implements on-policy distillation (OPD) following the approach described in the [Thinking Machines blog post](https://thinkingmachines.ai/blog/on-policy-distillation/). OPD combines the benefits of on-policy training with dense reward signals for efficient post-training.

## Overview

On-policy distillation trains a student model by:
1. Sampling trajectories from the student model itself
2. Using a teacher model to grade each token with dense rewards (per-token KL divergence)
3. Training the student to minimize reverse KL with the teacher

This approach is **10-30x more compute efficient** than traditional RL while achieving comparable or better performance.
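In code, one training step looks roughly like the sketch below; every name in it (`student`, `teacher`, `prompt_loader`, `response_mask`, `optimizer`) is a hypothetical stand-in rather than this app's actual API.

```python
# Schematic OPD step; all helper objects here are hypothetical stand-ins,
# not this app's real interfaces.
for step in range(num_steps):
    prompts = next(prompt_loader)                        # prompts only, no reference solutions
    trajectories = student.sample(prompts, n=4)          # 1. sample trajectories from the student
    student_logprobs = student.logprobs(trajectories)    # log-probs of its own sampled tokens
    teacher_logprobs = teacher.logprobs(trajectories)    # 2. teacher grades every sampled token
    per_token_kl = student_logprobs - teacher_logprobs   # 3. sampled reverse KL per token
    loss = (per_token_kl * response_mask).sum() / response_mask.sum()
    loss.backward()                                      # minimize reverse KL w.r.t. the student
    optimizer.step()
    optimizer.zero_grad()
```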

## Experimental Setup

### Models
- **Student**: Qwen3-1.7B-Base (or Qwen3-8B for larger experiments)
- **Teacher**: Qwen3-8B (or Qwen3-32B)
- **Evaluation**: AIME'24 benchmark

### Training Pipeline

#### Phase 1: Supervised Fine-Tuning (SFT)
First, establish a strong baseline through off-policy distillation:

```bash
python -m apps.sft.main --config apps/sft/qwen3_1_7b.yaml
```

- **Dataset**: OpenThoughts3-1.2M (400k prompts)
- **Expected Performance**: ~40% on AIME'24
- **Purpose**: Teaches the model basic math reasoning patterns

#### Phase 2: On-Policy Distillation
Refine the model using on-policy learning with dense supervision:

```bash
python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/default.yaml
```

- **Starting Point**: SFT checkpoint from Phase 1
- **Dataset**: Math prompts (from OpenThoughts3 or DeepMath; prompts only, no solutions — a sampling sketch follows this list)
- **Training**: ~150-200 steps (77k prompts with 4 samples each)
- **Expected Performance**: ~50% on AIME'24
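As a rough illustration of the prompts-only setup, the sketch below loads DeepMath prompts with `datasets` and samples several completions per prompt with vLLM; the `"question"` column name and the omitted chat formatting are assumptions, not verified against the dataset or this app.

```python
from datasets import load_dataset
from vllm import LLM, SamplingParams

# Prompts only: the dataset's reference solutions are never shown to the student during OPD.
ds = load_dataset("zwhe99/DeepMath-103K", split="train")
prompts = [row["question"] for row in ds.select(range(64))]  # "question" column assumed

student = LLM(model="./Qwen3-1.7B-Base-SFT")  # SFT checkpoint from Phase 1
params = SamplingParams(n=4, max_tokens=4096, temperature=0.6, top_p=0.95)  # mirrors default.yaml
outputs = student.generate(prompts, params)   # several trajectories per prompt, graded by the teacher next
```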

### Key Implementation Details

1. **Loss Function**: Per-token reverse KL divergence, estimated on the tokens the student sampled (a fuller sketch follows this list)
   ```python
   # Sampled estimate of KL(student || teacher) at each token; minimized as the loss
   reverse_kl = student_logprobs - teacher_logprobs
   ```

2. **Sampling**: Generate multiple trajectories per prompt (n=16 in config)

3. **Zero Discount Factor**: Each token's objective depends only on the immediate next token (discount factor = 0)

4. **Efficient Batching**: Can use smaller batch sizes than RL due to dense rewards
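A minimal PyTorch sketch of how these details might combine into a loss; tensor shapes and argument names are assumptions, not this app's actual trainer code.

```python
import torch
import torch.nn.functional as F

def per_token_reverse_kl_loss(student_logits: torch.Tensor,
                              teacher_logits: torch.Tensor,
                              sampled_tokens: torch.Tensor,
                              response_mask: torch.Tensor) -> torch.Tensor:
    """Masked per-token reverse KL on tokens the student itself sampled.

    student_logits, teacher_logits: [batch, seq, vocab]
    sampled_tokens:                 [batch, seq] token ids drawn from the student
    response_mask:                  [batch, seq] 1 on response tokens, 0 on prompt/padding
    """
    s_lp = F.log_softmax(student_logits, dim=-1).gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
    t_lp = F.log_softmax(teacher_logits, dim=-1).gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
    kl = s_lp - t_lp                                   # sampled estimate of KL(student || teacher)
    return (kl * response_mask).sum() / response_mask.sum().clamp(min=1)
```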

## Key Advantages

- **Compute Efficiency**: 10-30x reduction vs traditional RL
- **Dense Supervision**: Learns from every token, not just final rewards
- **Data Efficiency**: Can reuse prompts multiple times effectively
- **Stability**: More stable training than sparse RL rewards

## Notes for Reproduction

1. **Ensure proper initialization**: Load the SFT checkpoint before starting OPD
2. **Use prompts only**: During OPD, sample completions from the student; do not use the dataset's solutions
3. **Teacher quality matters**: Better teachers provide better supervision
4. **Monitor reverse KL**: Should go to near-zero as training progresses

## References

- [On-Policy Distillation Blog Post](https://thinkingmachines.ai/blog/on-policy-distillation/)
- [Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook)
- [OpenThoughts3 Dataset](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M)
120 changes: 120 additions & 0 deletions apps/on_policy_distillation/default.yaml
@@ -0,0 +1,120 @@
# On-Policy Distillation: Qwen 1.7B (student) learning from Qwen 8B (teacher)
# >>> python -m apps.on_policy_distillation.main --config apps/on_policy_distillation/default.yaml

# Global configuration
train_batch_size: 16 # Number of trajectories per training step
max_req_tokens: 2048
max_res_tokens: 4096
student_model: "./Qwen3-1.7B-Base-SFT" # Path to base model SFT'd on a math dataset
teacher_model: "Qwen/Qwen3-8B"

# Observability configuration
metric_logging:
wandb:
project: opd-training
group: opd_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
console:
logging_mode: global_reduce

# Dataset configuration
dataset:
path: "zwhe99/DeepMath-103K"
split: "train"

# Student generation configuration
student_generator:
engine_args:
model: ${student_model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
sampling_params:
n: 4
max_tokens: ${max_res_tokens}
temperature: 0.6
top_p: 0.95

# Student training configuration
trainer:
model:
name: qwen3
flavor: 1.7B
hf_assets_path: hf://${student_model}
optimizer:
name: AdamW
lr: 5e-5
eps: 1e-8
lr_scheduler:
warmup_steps: 0
training:
local_batch_size: ${train_batch_size} # Per-device batch size
seq_len: 8192
max_norm: 1.0
steps: 200
dtype: bfloat16
gc_freq: 5
compile:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: true
checkpoint:
enable: true
folder: ./checkpoint-opd
initial_load_path: ${student_model}
initial_load_model_only: true
initial_load_in_hf: true
last_save_in_hf: true
interval: 50
async_mode: "disabled"
activation_checkpoint:
mode: selective
selective_ac_option: op

# Teacher model configuration
teacher:
model:
name: qwen3
flavor: 8B
hf_assets_path: hf://${teacher_model}
training:
seq_len: ${trainer.training.seq_len}
dtype: bfloat16
gc_freq: 10
compile:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: 1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
checkpoint:
enable: true
initial_load_path: hf://${teacher_model}
initial_load_in_hf: true

# Resource allocations
services:
student_generator:
procs: 1
num_replicas: 4
mesh_name: student_generator
with_gpus: true
teacher:
procs: 1
num_replicas: 2
mesh_name: teacher
with_gpus: true
trainer:
procs: 1
num_replicas: 1
mesh_name: trainer
with_gpus: true
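Note: the `${...}` interpolations above (including `${oc.env:USER}`) follow OmegaConf's resolver syntax; below is a minimal sketch of resolving this file with OmegaConf, assuming that is indeed how the app loads its configs.

```python
from omegaconf import OmegaConf

# Load the YAML and expand interpolations such as ${student_model} and ${oc.env:USER}.
cfg = OmegaConf.load("apps/on_policy_distillation/default.yaml")
OmegaConf.resolve(cfg)
print(cfg.student_generator.sampling_params.n, cfg.trainer.optimizer.lr)
```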