|
| 1 | +import re |
1 | 2 |
|
2 | 3 | import pytest |
3 | 4 | from pydantic import BaseModel |
4 | 5 | from pydantic import ValidationError as PydValidationError |
5 | 6 |
|
6 | 7 | from fenic.api.functions import col, tool_param |
7 | | -from fenic.core.error import PlanError |
| 8 | +from fenic.core.error import PlanError, ValidationError |
8 | 9 | from fenic.core.mcp._tools import bind_tool, create_pydantic_model_for_tool |
9 | | -from fenic.core.mcp.types import ToolParam |
10 | | -from fenic.core.types.datatypes import IntegerType, StringType |
| 10 | +from fenic.core.mcp.types import ToolParam, ToolParamConstraints |
| 11 | +from fenic.core.types.datatypes import ArrayType, IntegerType, StringType |
11 | 12 |
|
12 | 13 |
|
13 | 14 | def test_toolparam_required_and_default_validation(): |
@@ -51,6 +52,103 @@ def test_resolve_tool_validates_unresolved_params(local_session): |
51 | 52 | query=query, |
52 | 53 | ) |
53 | 54 |
|
| 55 | +def test_resolve_tool_validates_mistyped_validators(local_session): |
| 56 | + df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]}) |
| 57 | + query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan |
| 58 | + |
| 59 | + with pytest.raises(PlanError, match="Param Validator `regex` supports data types \(StringType\), but the parameter `min_age` has data type IntegerType."): |
| 60 | + bind_tool( |
| 61 | + name="users_by_city", |
| 62 | + description="Filter users", |
| 63 | + params=[ |
| 64 | + ToolParam(name="min_age", description="Minimum age", validator_names=["regex"]), |
| 65 | + ToolParam(name="city_name", description="City name", validator_names=["regex"]), |
| 66 | + ], |
| 67 | + result_limit=50, |
| 68 | + query=query, |
| 69 | + ) |
| 70 | + |
| 71 | +def test_resolve_tool_validates_missing_validators(local_session): |
| 72 | + df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]}) |
| 73 | + query = df.filter((col("age") >= tool_param("min_age", IntegerType)) & (col("city") == tool_param("city_name", StringType)))._logical_plan |
| 74 | + |
| 75 | + with pytest.raises(PlanError, match="Could not find a ParamValidator for the following validator names: \['non_existent'\]"): |
| 76 | + bind_tool( |
| 77 | + name="users_by_city", |
| 78 | + description="Filter users", |
| 79 | + params=[ |
| 80 | + ToolParam(name="min_age", description="Minimum age"), |
| 81 | + ToolParam(name="city_name", description="City name", validator_names=["non_existent"]), |
| 82 | + ], |
| 83 | + result_limit=50, |
| 84 | + query=query, |
| 85 | + ) |
| 86 | + |
| 87 | +def test_create_pydantic_model_for_tool_applies_validators(local_session): |
| 88 | + df = local_session.create_dataframe({"name": ["Alice", "Bob"], "age": [25, 30], "city": ["SF", "SEA"]}) |
| 89 | + query = df.filter( |
| 90 | + (col("age") >= tool_param("min_age", IntegerType)) & |
| 91 | + (col("city") == tool_param("city_name", StringType)) |
| 92 | + )._logical_plan |
| 93 | + |
| 94 | + tool = bind_tool( |
| 95 | + name="users_by_city", |
| 96 | + description="Filter users", |
| 97 | + params=[ |
| 98 | + ToolParam(name="min_age", description="Minimum age"), |
| 99 | + ToolParam(name="city_name", description="City name", validator_names=["regex"]), |
| 100 | + ], |
| 101 | + result_limit=50, |
| 102 | + query=query, |
| 103 | + ) |
| 104 | + |
| 105 | + Model: type[BaseModel] = create_pydantic_model_for_tool(tool) |
| 106 | + |
| 107 | + with pytest.raises(ValidationError, match="Unbalanced curly braces"): |
| 108 | + Model(city_name="{+---", min_age=25) |
| 109 | + |
| 110 | + with pytest.raises(ValidationError, match="Too many alternations \(21 > 20\)"): |
| 111 | + Model(city_name="SF|SEA|OAK|PHX|LAS|ORD|XRD|PRD|IAD|CRD|FRA|LON|UMEA|BOS|YYZ|DOG|BAT|BAN|LAP|LAX|TYO|HND", min_age=25) |
| 112 | + |
| 113 | + |
| 114 | +def test_create_pydantic_model_for_tool_applies_field_validators(local_session):  # ToolParamConstraints on each param must be enforced by the generated pydantic model |
| 115 | + df = local_session.create_dataframe({"city": ["SF"], "age": [10], "user_name": ["Alice"]}) |
| 116 | + query = df.filter( |
| 117 | + (col("city") == tool_param("city_name", StringType)) |
| 118 | + & (col("age") >= tool_param("age", IntegerType)) |
| 119 | + & (col("user_name").is_in(tool_param("user_names", ArrayType(StringType)))) |
| 120 | + )._logical_plan |
| 121 | + |
| 122 | + tool = bind_tool( |
| 123 | + name="tool_x", |
| 124 | + description="", |
| 125 | + params=[ |
| 126 | + ToolParam(name="city_name", description="City name", constraints=ToolParamConstraints(pattern="^SF$")), |
| 127 | + ToolParam(name="age", description="Age", constraints=ToolParamConstraints(gt=0, lt=120, multiple_of=2)), |
| 128 | + ToolParam(name="user_names", description="User names", constraints=ToolParamConstraints(min_length=1, max_length=5)), |
| 129 | + ], |
| 130 | + result_limit=10, |
| 131 | + query=query, |
| 132 | + ) |
| 133 | + |
| 134 | + Model: type[BaseModel] = create_pydantic_model_for_tool(tool) |
| 135 | + # Baseline: every constraint is satisfied here, so construction must not raise. |
| 136 | + Model(city_name="SF", age=10, user_names=["Alice", "Bob"]) |
| 137 | + with pytest.raises(PydValidationError, match=re.escape("String should match pattern '^SF$'")): |
| 138 | + Model(city_name="SEA", age=10, user_names=["Alice", "Bob"]) |
| 139 | + |
| 140 | + with pytest.raises(PydValidationError, match=re.escape("Input should be greater than 0")): |
| 141 | + Model(city_name="SF", age=0, user_names=["Alice", "Bob"]) |
| 142 | + |
| 143 | + with pytest.raises(PydValidationError, match=re.escape("Input should be a multiple of 2")): |
| 144 | + Model(city_name="SF", age=11, user_names=["Alice", "Bob"]) |
| 145 | + |
| 146 | + with pytest.raises(PydValidationError, match=re.escape("List should have at most 5 items after validation, not 6")): |
| 147 | + Model(city_name="SF", age=10, user_names=["Alice", "Bob", "Charlie", "David", "Eve", "Frank"]) |
| 148 | + |
| 149 | + with pytest.raises(PydValidationError, match=re.escape("List should have at least 1 item after validation, not 0")): |
| 150 | + Model(city_name="SF", age=10, user_names=[]) |
| 151 | + |
54 | 152 |
|
55 | 153 | def test_create_pydantic_model_for_tool_defaults_and_required(local_session): |
56 | 154 | df = local_session.create_dataframe({"city": ["SF"], "age": [10]}) |
|
0 commit comments