diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py index fd5c3e8808..99e5e70578 100644 --- a/dspy/teleprompt/simba_utils.py +++ b/dspy/teleprompt/simba_utils.py @@ -12,6 +12,15 @@ logger = logging.getLogger(__name__) def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None): + """Prepares a list of language models for resampling by assigning unique rollout IDs. + + Creates n models with sequential rollout IDs. If teacher_settings is provided, the first + model uses the teacher's language model configuration. Remaining models are copies of the + base model with temperature set to 1.0. + + Returns: + A list of language models configured for resampling with unique rollout IDs. + """ lm = program.get_lm() or dspy.settings.lm start_rollout_id = lm.kwargs.get("rollout_id", 0) @@ -32,7 +41,26 @@ def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings return models def wrap_program(program: dspy.Module, metric: Callable): + """Wraps a program to capture its execution trace and evaluate it with a metric. + + Returns a function that executes the program on an example, captures the trace, + evaluates the prediction using the metric, and returns a dictionary containing + the prediction, trace, score, example, and any additional metadata from the metric. + The metric can return a numeric score or a dspy.Prediction with a score field. + + Returns: + A function that takes an example and returns a dictionary with prediction results, + trace, score, and metadata. + """ def wrapped_program(example): + """Executes the program on an example and captures its trace. + + Runs the program with the given example, captures the execution trace, evaluates + the result using the metric, and packages everything into a result dictionary. + + Returns: + A dictionary containing prediction, trace, score, example, and output_metadata. + """ with dspy.context(trace=[]): prediction, trace, score = None, None, 0.0 try: @@ -71,7 +99,25 @@ def wrapped_program(example): return wrapped_program def append_a_demo(demo_input_field_maxlen): + """Returns a function that appends demonstrations from a successful trajectory to predictors. + + The returned function extracts demonstrations from the best trajectory in a bucket and + appends them to the corresponding predictors. Input fields longer than demo_input_field_maxlen + are truncated. Skips appending if the best score is at or below the 10th percentile. + + Returns: + A function that processes a bucket and appends demonstrations to predictors. + """ def append_a_demo_(bucket, system, **kwargs): + """Extracts and appends demonstrations from the best trajectory to predictors. + + Processes the highest-scoring trajectory in the bucket, creates demonstrations from + each step, and appends them to the corresponding predictors. Truncates long input + fields and skips if the score is too low. + + Returns: + True if demonstrations were appended, False if skipped due to low score. + """ predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"] batch_10p_score = kwargs["batch_10p_score"] @@ -104,6 +150,16 @@ def append_a_demo_(bucket, system, **kwargs): def append_a_rule(bucket, system, **kwargs): + """Generates and appends advice to predictor instructions by comparing good and bad trajectories. + + Uses a language model to analyze the difference between a high-scoring and low-scoring + trajectory, generating module-specific advice. The advice is appended to each predictor's + instructions. Skips rule generation if the good score is too low or the bad score is too high + relative to batch percentiles. + + Returns: + True if advice was generated and appended, False if skipped due to score thresholds. + """ predictor2name = kwargs["predictor2name"] batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"] prompt_model = kwargs["prompt_model"] or dspy.settings.lm @@ -168,18 +224,11 @@ def append_a_rule(bucket, system, **kwargs): return True class OfferFeedback(dspy.Signature): - """ - You will be given two trajectories of an LLM-driven program's execution. Your goal is to help the program's modules - build up experience on how to maximize the reward value assigned to the program's outputs if it were to receive - similar inputs in the future. - - The module won't see its own history. It will rely on your advice balancing being concrete and being generalizable. - - In your advice: - - Avoid boilerplate. Offer advice that would change the module's behavior for the better in the future. - - Ensure that advice offered to a module M is specific to that M's specific sub-task, not the overall program. - - Rely on contrasting the behavior of the worse trajectory against the better trajectory in making recommendations. - - Ensure each unique module name appears exactly once as a key in the advice dictionary. + """Signature for generating module-specific advice by comparing successful and unsuccessful trajectories. + + Analyzes two program execution trajectories with different reward values to generate + concrete, actionable advice for each module. The advice helps modules improve their + behavior by learning from the contrast between better and worse trajectories. """ program_code: str = InputField(desc="The code of the program that we are analyzing") @@ -208,6 +257,15 @@ class OfferFeedback(dspy.Signature): ) def inspect_modules(program): + """Formats module information into a human-readable string representation. + + Extracts and formats each predictor's name, input fields, output fields, and instructions + into a structured text format with separators. The output is suitable for inclusion in + prompts or logs. + + Returns: + A formatted string containing module definitions with their fields and instructions. + """ separator = "-" * 80 output = [separator] @@ -228,6 +286,15 @@ def inspect_modules(program): def recursive_mask(o): + """Recursively masks non-serializable objects with placeholder strings. + + Traverses the object structure and replaces any non-JSON-serializable values with + a placeholder string indicating the type. Handles dictionaries, lists, and tuples + recursively while preserving already-serializable values. + + Returns: + The object with non-serializable values replaced by placeholder strings. + """ # If the object is already serializable, return it. try: orjson.dumps(o)