
Commit c74b8cc (parent: 507902a)

adding many tests for validation and tool creation in general, fix some issues that came up
6 files changed: +241, -65 lines

src/fenic/api/mcp/tool_generation.py

Lines changed: 9 additions & 9 deletions
@@ -466,15 +466,15 @@ def _auto_generate_sql_tool(
     """Create an Analyze tool that executes DuckDB SELECT SQL across datasets.
 
     - JOINs between the provided datasets are allowed.
-    - DDL/DML, CTEs, subqueries, UNION, and multiple top-level queries are not allowed (enforced upstream).
+    - DDL/DML, and multiple top-level queries are not allowed (enforced in `session.sql()`)
     - The callable returns a LogicalPlan gathered later by the MCP server.
     """
     if len(datasets) == 0:
         raise ConfigurationError("Cannot create SQL tool: no datasets provided.")
 
     async def analyze_func(
         full_sql: Annotated[
-            str, "Full SELECT SQL. Refer to DataFrames by name in braces, e.g., `SELECT * FROM {orders}`. JOINs between the provided datasets are allowed. SQL dialect: DuckDB. DDL/DML, CTEs, subqueries, UNION, and multiple top-level queries are not allowed"]
+            str, "Full SELECT SQL. Refer to DataFrames by name in braces, e.g., `SELECT * FROM {orders}`. JOINs between the provided datasets are allowed. SQL dialect: DuckDB. DDL/DML, and multiple top-level queries are not allowed"]
     ) -> LogicalPlan:
         return session.sql(full_sql.strip(), **{spec.table_name: spec.df for spec in datasets})._logical_plan
 
@@ -828,8 +828,8 @@ def _auto_generate_core_tools(
         datasets,
         session,
         tool_name=f"{tool_group_name} - Schema",
-        tool_description="\n\n".join([
-            "Show the schema (column names and types) for any or all of the datasets listed below. This call should be the first step in exploring the available datasets.",
+        tool_description="\n".join([
+            "Show the schema (column names and types) for any or all of the datasets listed below. This call should be the first step in exploring the available datasets:",
             group_desc,
         ]),
     )
@@ -852,7 +852,7 @@ def _auto_generate_core_tools(
         datasets,
         session,
         tool_name=f"{tool_group_name} - Read",
-        tool_description="\n\n".join([
+        tool_description="\n".join([
             "Read rows from a single dataset. Use to sample data, or to execute simple queries over the data that do not require filtering or grouping.",
             "Use `include_columns` and `exclude_columns` to filter columns by name -- this is important to conserve token usage. Use the `Profile` tool to understand the columns and their sizes.",
             "Available datasets:",
@@ -865,7 +865,7 @@ def _auto_generate_core_tools(
         datasets,
         session,
         tool_name=f"{tool_group_name} - Search Summary",
-        tool_description="\n\n".join([
+        tool_description="\n".join([
             "Perform a substring/regex search across all datasets and return a summary of the number of matches per dataset.",
             "Available datasets:",
             group_desc,
@@ -875,7 +875,7 @@ def _auto_generate_core_tools(
         datasets,
         session,
         tool_name=f"{tool_group_name} - Search Content",
-        tool_description="\n\n".join([
+        tool_description="\n".join([
             "Return matching rows from a single dataset using substring/regex across string columns.",
             "Available datasets:",
             group_desc,
@@ -887,9 +887,9 @@ def _auto_generate_core_tools(
         datasets,
         session,
         tool_name=f"{tool_group_name} - Analyze",
-        tool_description="\n\n".join([
+        tool_description="\n".join([
             "Execute Read-Only (SELECT) SQL over the provided datasets using fenic's SQL support.",
-            "DDL/DML, CTEs, subqueries, UNION, and multiple top-level queries are not allowed (enforced upstream).",
+            "DDL/DML, and multiple top-level queries are not allowed.",
             "For text search, prefer regular expressions (REGEXP_MATCHES()/REGEXP_EXTRACT()).",
             "Paging: use ORDER BY to define row order, then LIMIT and OFFSET for pages.",
             "JOINs between datasets are allowed. Refer to datasets by name in braces, e.g., {orders}.",

src/fenic/core/mcp/_tools.py

Lines changed: 7 additions & 14 deletions
@@ -13,7 +13,7 @@
 from fenic.core._logical_plan.plans.base import LogicalPlan
 from fenic.core._utils.type_inference import infer_pytype_from_dtype
 from fenic.core.error import PlanError
-from fenic.core.mcp._validators import get_param_validator, maybe_get_param_validator
+from fenic.core.mcp._validators import get_param_validator
 from fenic.core.mcp.types import (
     BoundToolParam,
     ParameterizedToolDefinition,
@@ -85,9 +85,10 @@ def bind_tool(
             try:
                 validator = get_param_validator(validator_name)
                 if unresolved_expr.data_type not in validator.data_types():
+                    supported_data_types = ", ".join([str(dt) for dt in validator.data_types()])
                     raise PlanError(
-                        f"Param Validator {validator_name} supports data types {validator.data_types()}, "
-                        f"but the parameter {unresolved_expr_name} has data type {unresolved_expr.data_type}."
+                        f"Param Validator `{validator_name}` supports data types ({supported_data_types}), "
+                        f"but the parameter `{unresolved_expr_name}` has data type {unresolved_expr.data_type}."
                     )
                 validators.append(validator)
             except KeyError:
@@ -132,19 +133,11 @@ def _infer_base_type(p: BoundToolParam):
         if isinstance(p.data_type, ArrayType):
             return list[literal_type]  # type: ignore[valid-type]
         return literal_type
+    if isinstance(p.data_type, ArrayType):
+        inner_type = infer_pytype_from_dtype(p.data_type.element_type)
+        return list[inner_type]  # type: ignore[valid-type]
     return infer_pytype_from_dtype(p.data_type)
 
-def _wrap_with_validator(base_t, validator_name: Optional[str]):
-    if not validator_name:
-        return base_t
-    pv = maybe_get_param_validator(validator_name)
-    if pv is None:
-        return base_t
-    def _wrap(v, _pv=pv):
-        _pv.validate(v)
-        return v
-    return TypingAnnotated[base_t, AfterValidator(_wrap)]  # type: ignore[valid-type]
-
 def _field_kwargs(p: BoundToolParam, include_default: bool) -> dict:
     kwargs: dict = {"description": p.description}
     constraints = p.constraints
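The new `ArrayType` branch in `_infer_base_type` builds `list[T]` at runtime from the element type. A self-contained sketch of that pattern, using a stand-in dtype mapping rather than fenic's `infer_pytype_from_dtype`:

```python
# Stand-in for fenic's infer_pytype_from_dtype, which resolves real DataTypes.
def dtype_to_pytype(dtype: str) -> type:
    return {"string": str, "integer": int}[dtype]

def infer_array_pytype(element_dtype: str):
    inner_type = dtype_to_pytype(element_dtype)
    return list[inner_type]  # runtime parametrization, as in _infer_base_type

# e.g. the equivalent of ArrayType(StringType) now resolves to list[str]
assert infer_array_pytype("string") == list[str]
```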

src/fenic/core/mcp/_validators.py

Lines changed: 13 additions & 39 deletions
@@ -1,6 +1,7 @@
 import re
-from typing import Dict, List, Optional, Protocol, Union, runtime_checkable
+from typing import Dict, List, Protocol, Union, runtime_checkable
 
+from fenic._polars_plugins import py_validate_regex  # noqa: F401
 from fenic.core.error import (
     ValidationError,
 )
@@ -21,7 +22,7 @@ def data_types(self) -> List[DataType]:
         """The data types that the validator operates on."""
         ...
 
-    def validate(self, value: Union[str, int, float, bool, list, dict]) -> bool:
+    def validate(self, value: Union[str, int, float, bool, list, dict]):
         """Validate an argument value.
 
         Args:
@@ -62,14 +63,6 @@ def validate(self, user_query: str):
         if len(query) > MAX_REGEX_LENGTH:
             raise ValidationError(f"Regex too long (>{MAX_REGEX_LENGTH} characters)")
 
-        # Support /pattern/flags and capture flags
-        query, flags = self._strip_slash_delimiters(query)
-        unsupported_flags = {f for f in flags if f not in {"i", "m", "s", "x"}}
-        if unsupported_flags:
-            raise ValidationError(
-                f"Unsupported regex flags: {''.join(sorted(unsupported_flags))}"
-            )
-
         # Strip inline flags at start like (?i), (?m), combined, to avoid duplication
         query = re.sub(r"^\(\?[aiLmsux]+\)", "", query)
 
@@ -89,15 +82,16 @@ def validate(self, user_query: str):
         except ValueError:
             raise ValidationError("Invalid quantifier bounds") from None
         if m_val > MAX_QUANTIFIER_VALUE or n_val > MAX_QUANTIFIER_VALUE:
-            raise ValidationError("Quantifier bounds too large")
+            raise ValidationError(f"Quantifier bounds {m_val} or {n_val} > {MAX_QUANTIFIER_VALUE}")
         if n and n_val < m_val:
-            raise ValidationError("Quantifier upper bound less than lower bound")
+            raise ValidationError(f"Quantifier upper bound {n_val} < lower bound {m_val}")
 
         # Limit alternations
-        if query.count("|") > MAX_ALTERNATIONS:
-            raise ValidationError("Too many alternations in regex")
+        alternations = query.count("|")
+        if alternations > MAX_ALTERNATIONS:
+            raise ValidationError(f"Too many alternations ({alternations} > {MAX_ALTERNATIONS})")
 
-        # Disallow backreferences (simple and robust detection)
+        # Disallow backreferences
         if any(f"\\{d}" in query for d in "123456789"):
             raise ValidationError("Backreferences are not supported")
 
@@ -121,11 +115,11 @@ def validate(self, user_query: str):
         if re.search(r"\{\s*\d+\s*,\s*\d+\s*,", query):
             raise ValidationError("Invalid quantifier syntax")
 
-        # Ensure it compiles in Python as a basic sanity check
+        # Final check, ensure that the regex is valid for `rlike`
         try:
-            re.compile(query)
-        except re.error as err:
-            raise ValidationError(f"Invalid regex syntax: {err}") from None
+            py_validate_regex(query)
+        except Exception as err:
+            raise ValidationError(f"Invalid regex syntax: {query}") from err
 
         return
 
@@ -146,20 +140,6 @@ def _is_balanced(self, s: str, open_char: str, close_char: str) -> bool:
             i += 1
         return depth == 0
 
-
-    def _strip_slash_delimiters(self, pattern: str) -> tuple[str, set[str]]:
-        """Support /pattern/flags syntax; return (pattern, flags).
-
-        Only recognize i,m,s,x flags; others are rejected later.
-        """
-        if len(pattern) >= 2 and pattern.startswith("/") and pattern.rfind("/") > 0:
-            last = pattern.rfind("/")
-            core = pattern[1:last]
-            flags = set(pattern[last + 1 :].lower())
-            return core, flags
-        return pattern, set()
-
-
 # -- Registry for reusable ParamValidators --
 _PARAM_VALIDATOR_REGISTRY: Dict[str, ParamValidator] = {}
 
@@ -186,11 +166,5 @@ def get_param_validator(name: str) -> ParamValidator:
         raise KeyError(f"No ParamValidator registered under name '{name}'") from err
 
 
-def maybe_get_param_validator(name: Optional[str]) -> Optional[ParamValidator]:
-    if name is None:
-        return None
-    return get_param_validator(name)
-
-
 # Pre-register common validators
 register_param_validator("regex", RegexValidator())
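Given the registry above, custom validators presumably plug in the same way the built-in `regex` validator does. A hedged sketch of a validator conforming to the `ParamValidator` protocol (the `UppercaseValidator` itself is invented for illustration; the protocol methods and registry functions are from the diff):

```python
from typing import List, Union

from fenic.core.error import ValidationError
from fenic.core.mcp._validators import get_param_validator, register_param_validator
from fenic.core.types.datatypes import StringType


class UppercaseValidator:
    """Illustrative validator: rejects strings containing lowercase characters."""

    def data_types(self) -> List:
        # Declares which fenic DataTypes this validator accepts; bind_tool
        # raises a PlanError if a parameter's type is not in this list.
        return [StringType]

    def validate(self, value: Union[str, int, float, bool, list, dict]):
        if isinstance(value, str) and value != value.upper():
            raise ValidationError(f"Value must be uppercase: {value!r}")


register_param_validator("uppercase", UppercaseValidator())
get_param_validator("uppercase").validate("SEA")  # passes; "sea" would raise
```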

tests/api/mcp/test_server.py

Lines changed: 46 additions & 0 deletions
@@ -2,13 +2,59 @@
 
 import pytest
 
+from fenic.api.functions import col, tool_param
 from fenic.api.mcp.server import create_mcp_server
 from fenic.api.mcp.tool_generation import auto_generate_core_tools_from_tables
 from fenic.api.session.session import Session
 from fenic.core._utils.misc import to_snake_case
+from fenic.core.mcp._tools import bind_tool
+from fenic.core.mcp.types import ToolParam, ToolParamConstraints
+from fenic.core.types.datatypes import ArrayType, IntegerType, StringType
 from tests.api.mcp.utils import create_table_with_rows
 
 
+def test_server_generation_with_parameterized_tools(local_session: Session):
+    pytest.importorskip("fastmcp")
+    df = local_session.create_dataframe({"city": ["SF"], "age": [10], "user_name": ["Alice"]})
+    query = df.filter(
+        (col("city") == tool_param("city_name", StringType))
+        & (col("age") >= tool_param("age", IntegerType))
+        & (col("user_name").is_in(tool_param("user_names", ArrayType(StringType))))
+    )._logical_plan
+
+    parameterized_tool = bind_tool(
+        name="tool_x",
+        description="table one",
+        params=[
+            ToolParam(name="city_name", description="City name", constraints=ToolParamConstraints(pattern="^SF$")),
+            ToolParam(name="age", description="Age", constraints=ToolParamConstraints(gt=0, lt=120, multiple_of=2)),
+            ToolParam(name="user_names", description="User names", constraints=ToolParamConstraints(min_length=1, max_length=5)),
+        ],
+        result_limit=10,
+        query=query,
+    )
+
+    server = create_mcp_server(local_session, "Test Server", parameterized_tools=[parameterized_tool])
+    server_tools = asyncio.run(server.mcp.get_tools())
+    assert len(server_tools) == 1
+    parameter_schema = server_tools["tool_x"].parameters['properties']
+    city_name_param = parameter_schema['city_name']
+    assert city_name_param['type'] == 'string'
+    assert city_name_param['pattern'] == '^SF$'
+    assert city_name_param['description'] == "City name"
+    age_param = parameter_schema['age']
+    assert age_param['type'] == 'integer'
+    assert age_param['exclusiveMinimum'] == 0
+    assert age_param['exclusiveMaximum'] == 120
+    assert age_param['multipleOf'] == 2
+    assert age_param['description'] == "Age"
+    user_names_param = parameter_schema['user_names']
+    assert user_names_param['type'] == 'array'
+    assert user_names_param['items']['type'] == 'string'
+    assert user_names_param['maxItems'] == 5
+    assert user_names_param['minItems'] == 1
+    assert user_names_param['description'] == "User names"
+
 def test_server_generation(local_session: Session):
     pytest.importorskip("fastmcp")
     create_table_with_rows(local_session, "t1", [1, 2, 3], description="table one")
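The JSON Schema assertions in this test line up with Pydantic v2's own constraint serialization, which suggests `ToolParamConstraints` values are forwarded as `Field` kwargs. A plain-Pydantic sketch of that mapping, outside fenic; `ToolXArgs` is a hypothetical model, not fenic's generated one:

```python
from pydantic import BaseModel, Field


class ToolXArgs(BaseModel):
    city_name: str = Field(description="City name", pattern="^SF$")  # -> "pattern"
    age: int = Field(description="Age", gt=0, lt=120, multiple_of=2)  # -> exclusiveMinimum/exclusiveMaximum/multipleOf
    user_names: list[str] = Field(description="User names", min_length=1, max_length=5)  # -> minItems/maxItems


props = ToolXArgs.model_json_schema()["properties"]
assert props["age"]["exclusiveMinimum"] == 0 and props["age"]["multipleOf"] == 2
assert props["user_names"]["minItems"] == 1 and props["user_names"]["maxItems"] == 5
```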

tests/core/mcp/test_tools.py

Lines changed: 101 additions & 3 deletions
@@ -1,13 +1,14 @@
+import re
 
 import pytest
 from pydantic import BaseModel
 from pydantic import ValidationError as PydValidationError
 
 from fenic.api.functions import col, tool_param
-from fenic.core.error import PlanError
+from fenic.core.error import PlanError, ValidationError
 from fenic.core.mcp._tools import bind_tool, create_pydantic_model_for_tool
-from fenic.core.mcp.types import ToolParam
-from fenic.core.types.datatypes import IntegerType, StringType
+from fenic.core.mcp.types import ToolParam, ToolParamConstraints
+from fenic.core.types.datatypes import ArrayType, IntegerType, StringType
 
 
 def test_toolparam_required_and_default_validation():
@@ -51,6 +52,103 @@ def test_resolve_tool_validates_unresolved_params(local_session):
         query=query,
     )
 
+def test_resolve_tool_validates_mistyped_validators(local_session):
+    df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
+    query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan
+
+    with pytest.raises(PlanError, match="Param Validator `regex` supports data types \(StringType\), but the parameter `min_age` has data type IntegerType."):
+        bind_tool(
+            name="users_by_city",
+            description="Filter users",
+            params=[
+                ToolParam(name="min_age", description="Minimum age", validator_names=["regex"]),
+                ToolParam(name="city_name", description="City name", validator_names=["regex"]),
+            ],
+            result_limit=50,
+            query=query,
+        )
+
+def test_resolve_tool_validates_missing_validators(local_session):
+    df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
+    query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan
+
+    with pytest.raises(PlanError, match="Could not find a ParamValidator for the following validator names: \['non_existent'\]"):
+        bind_tool(
+            name="users_by_city",
+            description="Filter users",
+            params=[
+                ToolParam(name="min_age", description="Minimum age"),
+                ToolParam(name="city_name", description="City name", validator_names=["non_existent"]),
+            ],
+            result_limit=50,
+            query=query,
+        )
+
+def test_create_pydantic_model_for_tool_applies_validators(local_session):
+    df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]})
+    query = df.filter(
+        (col("age") >= tool_param("min_age", IntegerType)) &
+        (col("city") == tool_param("city_name", StringType))
+    )._logical_plan
+
+    tool = bind_tool(
+        name="users_by_city",
+        description="Filter users",
+        params=[
+            ToolParam(name="min_age", description="Minimum age"),
+            ToolParam(name="city_name", description="City name", validator_names=["regex"]),
+        ],
+        result_limit=50,
+        query=query,
+    )
+
+    Model: type[BaseModel] = create_pydantic_model_for_tool(tool)
+
+    with pytest.raises(ValidationError, match="Unbalanced curly braces"):
+        Model(city_name="{+---", min_age=25)
+
+    with pytest.raises(ValidationError, match="Too many alternations \(21 > 20\)"):
+        Model(city_name="SF|SEA|OAK|PHX|LAS|ORD|XRD|PRD|IAD|CRD|FRA|LON|UMEA|BOS|YYZ|DOG|BAT|BAN|LAP|LAX|TYO|HND", min_age=25)
+
+
+def test_create_pydantic_model_for_tool_applies_field_validators(local_session):
+    df = local_session.create_dataframe({"city": ["SF"], "age": [10], "user_name": ["Alice"]})
+    query = df.filter(
+        (col("city") == tool_param("city_name", StringType))
+        & (col("age") >= tool_param("age", IntegerType))
+        & (col("user_name").is_in(tool_param("user_names", ArrayType(StringType))))
+    )._logical_plan
+
+    tool = bind_tool(
+        name="tool_x",
+        description="",
+        params=[
+            ToolParam(name="city_name", description="City name", constraints=ToolParamConstraints(pattern="^SF$")),
+            ToolParam(name="age", description="Age", constraints=ToolParamConstraints(gt=0, lt=120, multiple_of=2)),
+            ToolParam(name="user_names", description="User names", constraints=ToolParamConstraints(min_length=1, max_length=5)),
+        ],
+        result_limit=10,
+        query=query,
+    )
+
+    Model: type[BaseModel] = create_pydantic_model_for_tool(tool)
+    # should pass validation
+    Model(city_name="SF", age=10, user_names=["Alice", "Bob"])
+    with pytest.raises(PydValidationError, match=re.escape("String should match pattern '^SF$'")):
+        Model(city_name="SEA", age=10, user_names=["Alice", "Bob"])
+
+    with pytest.raises(PydValidationError, match=re.escape("Input should be greater than 0")):
+        Model(city_name="SF", age=0, user_names=["Alice", "Bob"])
+
+    with pytest.raises(PydValidationError, match=re.escape("Input should be a multiple of 2")):
+        Model(city_name="SF", age=11, user_names=["Alice", "Bob"])
+
+    with pytest.raises(PydValidationError, match=re.escape("List should have at most 5 items after validation, not 6")):
+        Model(city_name="SF", age=10, user_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"])
+
+    with pytest.raises(PydValidationError, match=re.escape("List should have at least 1 item after validation, not 0")):
+        Model(city_name="SF", age=10, user_names=[])
+
 
 def test_create_pydantic_model_for_tool_defaults_and_required(local_session):
     df = local_session.create_dataframe({"city": ["SF"], "age": [10]})
