@iRonJ iRonJ commented Oct 22, 2025

Add MaskedRepetitionContext for VLM Image Token Exclusion

Overview

This PR introduces MaskedRepetitionContext, a new LogitProcessor that extends the existing repetition penalty functionality to support excluding specific tokens (such as image tokens in Vision-Language Models) from repetition penalties.

Problem

In Vision-Language Models (VLMs), image patch tokens often need to repeat naturally to represent visual content. The existing RepetitionContext applies penalties to all repeated tokens, which can degrade VLM performance by incorrectly penalizing legitimate image token repetitions.

Solution

MaskedRepetitionContext accepts a boolean mask array that identifies which tokens should be excluded from repetition penalty calculation, allowing:

  • Text tokens: Receive normal repetition penalty to maintain quality
  • Image tokens: Repeat freely without penalty to preserve visual understanding

Basic Usage Example

import MLX
import MLXLMCommon

// Create a MaskedRepetitionContext processor
var processor = MaskedRepetitionContext(
    repetitionPenalty: 1.1,  // Penalty factor: positive logits are divided by 1.1, negative ones multiplied
    repetitionContextSize: 20 // Consider the last 20 tokens for repetition
)

// Example prompt tokens where token 32000 is an image token
let promptTokens = [1, 15, 32000, 32000, 42, 123] // 32000 = image token
let imageMask = [false, false, true, true, false, false] // true = exclude from penalty

// Initialize the processor with prompt and mask
let promptArray = MLXArray(promptTokens)
processor.prompt(promptArray, mask: imageMask)

// During generation: only tokens [1, 15, 42, 123] will be penalized
// Image tokens [32000, 32000] can repeat without penalty
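
To make the effect of the mask concrete, the following plain-Swift sketch (no MLX dependency) computes which prompt tokens remain eligible for the penalty. The helper `penalizedTokens(in:mask:)` is hypothetical and not part of this PR's API; it only illustrates the filtering that the mask implies.

```swift
// Hypothetical helper for illustration only; not part of MaskedRepetitionContext.
// Returns the tokens whose mask entry is false, i.e. those still subject to the penalty.
func penalizedTokens(in tokens: [Int], mask: [Bool]) -> [Int] {
    precondition(tokens.count == mask.count, "mask must have one entry per token")
    return zip(tokens, mask).filter { !$0.1 }.map { $0.0 }
}

let examplePromptTokens = [1, 15, 32000, 32000, 42, 123]
let exampleImageMask = [false, false, true, true, false, false]
let eligible = penalizedTokens(in: examplePromptTokens, mask: exampleImageMask)
print(eligible)  // [1, 15, 42, 123]
```

The precondition mirrors the natural requirement that the mask has exactly one entry per token; whether the PR enforces it the same way is an assumption.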

Integration with TokenIterator

// Use with TokenIterator for generation
let imageTokenId = 32000  // image token ID from the example above; model-specific
let sampler = CategoricalSampler(temperature: 0.7)
let iterator = try TokenIterator(
    input: lmInput,
    model: model,
    processor: processor,  // Your MaskedRepetitionContext
    sampler: sampler,
    maxTokens: 100
)

// Generate tokens - image tokens won't be penalized even if they repeat.
// TokenIterator is a synchronous Sequence that yields Int token IDs.
for tokenId in iterator {
    let isImageToken = (tokenId == imageTokenId)

    // Update processor with mask information for new tokens
    processor.didSample(token: MLXArray(tokenId), isMasked: isImageToken)
}

Files Changed

  • Evaluate.swift: Added MaskedRepetitionContext implementation
  • Tests/MLXLMTests/RepetitionPenaltyTests.swift: Comprehensive test suite
  • mlx-swift-examples.xcodeproj/project.pbxproj: Added test file to build system

Key Features

  • Backward Compatible: Implements the same LogitProcessor interface as RepetitionContext
  • Flexible Masking: Supports any token type that should be excluded from the penalty
  • Efficient Implementation: Uses a circular buffer with O(1) token insertion
  • VLM Optimized: Designed specifically for Vision-Language Model requirements
  • Comprehensive Testing: Full test coverage including edge cases
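
As a rough illustration of the circular-buffer idea, here is a minimal plain-Swift sketch of a fixed-capacity ring that stores each token alongside its mask flag and overwrites the oldest entry in O(1). The type `MaskedTokenRing` and its members are assumptions made for this example; the PR's actual implementation works on MLXArray values and may be structured differently.

```swift
// Illustrative fixed-capacity ring buffer of (token, isMasked) pairs.
// The type and its members are assumptions for this sketch, not the PR's code.
struct MaskedTokenRing {
    private var tokens: [Int]
    private var masked: [Bool]
    private var next = 0        // slot the next append writes to
    private(set) var count = 0  // number of valid entries, at most capacity

    init(capacity: Int) {
        precondition(capacity > 0, "capacity must be positive")
        tokens = Array(repeating: 0, count: capacity)
        masked = Array(repeating: false, count: capacity)
    }

    // O(1): once the buffer is full, the oldest slot is overwritten.
    mutating func append(token: Int, isMasked: Bool) {
        tokens[next] = token
        masked[next] = isMasked
        next = (next + 1) % tokens.count
        count = min(count + 1, tokens.count)
    }

    // Distinct unmasked tokens currently inside the window.
    var penalizedTokens: Set<Int> {
        var result = Set<Int>()
        for i in 0..<count where !masked[i] {
            result.insert(tokens[i])
        }
        return result
    }
}

var ring = MaskedTokenRing(capacity: 3)
ring.append(token: 7, isMasked: false)
ring.append(token: 32000, isMasked: true)
ring.append(token: 9, isMasked: false)
ring.append(token: 11, isMasked: false)  // overwrites token 7
print(ring.penalizedTokens.sorted())  // [9, 11]
```

Storing the mask flag next to each token keeps the two in sync when old entries are evicted, which is the main reason to pair them in one structure.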

Testing

Running the Tests

To run the comprehensive test suite for repetition penalty functionality:

# Run all tests
xcodebuild test -scheme mlx-libraries-Package

# Run specific repetition penalty tests
xcodebuild test -scheme mlx-libraries-Package -only-testing:MLXLMTests.RepetitionPenaltyTests

What We're Testing

The test suite (RepetitionPenaltyTests.swift) validates:

  1. testBasicRepetitionContext: Verifies existing RepetitionContext functionality remains intact
  2. testMaskedRepetitionContextBasic: Tests basic masking behavior - masked tokens are excluded from penalty
  3. testMaskedRepetitionContextAllMasked: Edge case where all tokens are masked (no penalties applied)
  4. testMaskedRepetitionContextDuringGeneration: Complex scenario simulating actual generation with mixed masked/unmasked tokens
  5. testMaskedRepetitionContextCircularBuffer: Validates circular buffer behavior when context window is exceeded
  6. testMaskedRepetitionContextFallbackBehavior: Tests backward compatibility when no mask is provided
  7. testMaskedRepetitionContextPreconditions: Validates error handling and input validation
  8. testComparisonBetweenProcessors: Direct comparison between RepetitionContext and MaskedRepetitionContext behavior

Test Coverage Highlights

  • Penalty Application Logic: Verifies correct penalty calculation (division for positive logits, multiplication for negative)
  • Mask Handling: Ensures only unmasked tokens receive penalties
  • Memory Management: Tests circular buffer behavior and context window management
  • Edge Cases: Handles empty contexts, all-masked scenarios, and boundary conditions
  • Integration: Validates compatibility with existing MLX generation pipeline
  • Performance: Confirms O(1) token operations and efficient mask processing
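
The penalty rule named in the first bullet (division for positive logits, multiplication for negative ones) can be sketched on a plain `[Float]` vector as follows. The function `applyRepetitionPenalty(to:penalizedTokens:penalty:)` is hypothetical and only mirrors the rule described above; the PR itself operates on MLXArray logits.

```swift
// Sketch of the penalty rule on a plain [Float] logit vector.
// Assumption: the penalty is applied once per distinct unmasked token in the context.
func applyRepetitionPenalty(
    to logits: [Float], penalizedTokens: Set<Int>, penalty: Float
) -> [Float] {
    var result = logits
    for token in penalizedTokens where logits.indices.contains(token) {
        let value = logits[token]
        // Dividing a positive logit or multiplying a negative one both lower the score.
        result[token] = value > 0 ? value / penalty : value * penalty
    }
    return result
}

let logits: [Float] = [2.0, -1.0, 4.0, 0.5]
let penalized = applyRepetitionPenalty(to: logits, penalizedTokens: [0, 1], penalty: 2.0)
print(penalized)  // [1.0, -2.0, 4.0, 0.5]
```

Tokens 2 and 3 are left untouched here, which is the same outcome a masked image token would see: it never enters the penalized set, so its logit passes through unchanged.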

Benefits for VLMs

  • Improved Generation Quality: VLMs can now apply repetition penalties selectively
  • Better Image Understanding: Image tokens repeat naturally without artificial constraints
  • Maintained Text Quality: Text tokens still receive appropriate repetition penalties
  • Easy Integration: Drop-in replacement for existing repetition penalty usage

Breaking Changes

None. This is a purely additive feature that maintains full backward compatibility with existing RepetitionContext usage.

@davidkoski
Collaborator

This is failing the swift-format check. Please make sure you are using swift-format 602.0.0.
