mesa_frames/concrete/datacollector.py: 112 changes (89 additions, 23 deletions)
@@ -64,14 +64,15 @@ def step(self):
from collections.abc import Callable
from mesa_frames import Model
from psycopg2.extensions import connection
import logging


class DataCollector(AbstractDataCollector):
def __init__(
self,
model: Model,
model_reporters: dict[str, Callable] | None = None,
agent_reporters: dict[str, str | Callable] | None = None,
agent_reporters: dict[str, str | Callable] | None = None,  # now also accepts Callable
trigger: Callable[[Any], bool] | None = None,
reset_memory: bool = True,
storage: Literal[
@@ -91,7 +92,10 @@ def __init__(
model_reporters : dict[str, Callable] | None
Functions to collect data at the model level.
agent_reporters : dict[str, str | Callable] | None
Attributes or functions to collect data at the agent level.
A dictionary mapping output column names to existing agent
column names (str) or callables. Callables are accepted for API
compatibility but are not currently processed by the agent data collector.
Example: {"agent_wealth": "wealth", "age_in_years": "age"}
trigger : Callable[[Any], bool] | None
A function(model) -> bool that determines whether to collect data.
reset_memory : bool
@@ -105,6 +109,18 @@ def __init__(
max_worker : int
Maximum number of worker threads used for flushing collected data asynchronously
"""
if agent_reporters:
for key, value in agent_reporters.items():
if not isinstance(key, str):
raise TypeError(
f"Agent reporter keys must be strings (the final column name), not a {type(key)}."
)
if not (isinstance(value, str) or callable(value)):
raise TypeError(
f"Agent reporter for '{key}' must be either a string (the source column name) "
f"or a callable (function taking an agent and returning a value), not a {type(value)}."
)

super().__init__(
model=model,
model_reporters=model_reporters,
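
The guard above fails fast on malformed reporter specs. A minimal standalone sketch of the rule it enforces (the helper name and the sample dicts are illustrative, not part of the PR):

```python
def check_agent_reporters(agent_reporters: dict) -> None:
    # Mirrors the type check added in __init__ above (illustrative helper).
    for key, value in agent_reporters.items():
        if not isinstance(key, str):
            raise TypeError(f"keys must be str, not {type(key).__name__}")
        if not (isinstance(value, str) or callable(value)):
            raise TypeError(f"value for '{key}' must be str or callable")


check_agent_reporters({"agent_wealth": "wealth"})     # ok: str source column
check_agent_reporters({"wealth_sq": lambda m: None})  # ok: callables are accepted
# check_agent_reporters({"bad": 42})                  # would raise TypeError
```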
@@ -174,25 +190,71 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
"""
Collect agent-level data using the agent_reporters.

Constructs a LazyFrame with one column per reporter and
includes `step` and `seed` metadata. Appends it to internal storage.
This method iterates through all AgentSets in the model, selects the
`unique_id` and the requested reporter columns from each AgentSet's
DataFrame, adds an `agent_type` column, and concatenates them into a
single "long" format LazyFrame, to which `step`, `seed`, and `batch`
metadata columns are appended.
"""
agent_data_dict = {}
for col_name, reporter in self._agent_reporters.items():
if isinstance(reporter, str):
for k, v in self._model.sets[reporter].items():
agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
else:
agent_data_dict[col_name] = reporter(self._model)
agent_lazy_frame = pl.LazyFrame(agent_data_dict)
agent_lazy_frame = agent_lazy_frame.with_columns(
all_agent_frames = []
reporter_map = self._agent_reporters

try:
agent_sets_list = self._model.sets._agentsets
except AttributeError:
logging.error(
"DataCollector could not find '_agentsets' attribute on model.sets. "
"Agent data collection will be skipped."
)
return

for agent_set in agent_sets_list:
if not hasattr(agent_set, "df"):
logging.warning(
f"AgentSet {agent_set.__class__.__name__} has no 'df' attribute. Skipping."
)
continue

agent_df = agent_set.df.lazy()
agent_type = agent_set.__class__.__name__
available_cols = agent_df.columns

if "unique_id" not in available_cols:
logging.warning(
f"AgentSet {agent_type} 'df' has no 'unique_id' column. Skipping."
)
continue

cols_to_select = [pl.col("unique_id")]

for final_name, source_col in reporter_map.items():
if source_col in available_cols:
## Add the column, aliasing it if the key is different
cols_to_select.append(pl.col(source_col).alias(final_name))

## Only proceed if we have more than just unique_id
if len(cols_to_select) > 1:
set_frame = agent_df.select(cols_to_select)
## Add the agent_type column
set_frame = set_frame.with_columns(
pl.lit(agent_type).alias("agent_type")
)
all_agent_frames.append(set_frame)

if not all_agent_frames:
return

## Combine all agent set DataFrames into one
final_agent_frame = pl.concat(all_agent_frames, how="diagonal_relaxed")

## Add metadata and append
final_agent_frame = final_agent_frame.with_columns(
[
pl.lit(current_model_step).alias("step"),
pl.lit(str(self.seed)).alias("seed"),
pl.lit(batch_id).alias("batch"),
]
)
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
self._frames.append(("agent", current_model_step, batch_id, final_agent_frame))

@property
def data(self) -> dict[str, pl.DataFrame]:
@@ -461,13 +523,20 @@ def _validate_reporter_table_columns(
If any expected columns are missing from the table.
"""
expected_columns = set()

## Add columns required for the new long agent format
if table_name == "agent_data":
expected_columns.add("unique_id")
expected_columns.add("agent_type")

## Add all keys from the reporter dict
for col_name, required_column in reporter.items():
if isinstance(required_column, str):
for k, v in self._model.sets[required_column].items():
expected_columns.add(
(col_name + "_" + str(k.__class__.__name__)).lower()
)
if table_name == "agent_data":
if isinstance(required_column, str):
expected_columns.add(col_name.lower())
## Callables are not supported for agents
else:
## For model, all reporters are callable
expected_columns.add(col_name.lower())

query = f"""
@@ -484,10 +553,7 @@ def _validate_reporter_table_columns(

existing_columns = {row[0] for row in result}
missing_columns = expected_columns - existing_columns
required_columns = {
"step": "Integer",
"seed": "Varchar",
}
required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"}

missing_required = {
col: col_type
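
For reference, a PostgreSQL table that would satisfy the updated agent_data validation might look like the sketch below. Only the column names follow from the checks above; the types (and the example reporter column `agent_wealth`) are assumptions:

```python
# Assumed schema for agent_reporters = {"agent_wealth": "wealth"}.
AGENT_DATA_DDL = """
CREATE TABLE IF NOT EXISTS agent_data (
    unique_id    BIGINT,
    agent_type   VARCHAR,
    agent_wealth DOUBLE PRECISION,
    step         INTEGER,
    seed         VARCHAR,
    batch        INTEGER
);
"""

# What the validator derives for that reporter dict:
expected_columns = {"unique_id", "agent_type", "agent_wealth"}
required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"}
```

Note that `batch` is now required alongside `step` and `seed`, so tables created for the old wide format will need a migration.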