
Commit ef58d6d

Add openai support for semantic parse_pdf
1 parent d44974b commit ef58d6d

15 files changed: +237 -100 lines

src/fenic/_inference/anthropic/anthropic_batch_chat_completions_client.py

Lines changed: 2 additions & 2 deletions
@@ -275,7 +275,7 @@ def _estimate_structured_output_overhead(self, response_format) -> int:
         """
         return self.estimate_response_format_tokens(response_format)

-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
         """Get maximum output tokens including thinking budget.

         Args:
@@ -329,7 +329,7 @@ def estimate_tokens_for_request(self, request: FenicCompletionsRequest):
         input_tokens += self._count_auxiliary_input_tokens(request)

         # Estimate output tokens
-        output_tokens = self._get_max_output_tokens(request)
+        output_tokens = self._get_max_output_token_request_limit(request)

         return TokenEstimate(
             input_tokens=input_tokens,

src/fenic/_inference/cohere/cohere_batch_embeddings_client.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def estimate_tokens_for_request(self, request: FenicEmbeddingsRequest) -> TokenE
             output_tokens=0
         )

-    def _get_max_output_tokens(self, request: FenicEmbeddingsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicEmbeddingsRequest) -> int:
         """Get maximum output tokens (always 0 for embeddings).

         Returns:

src/fenic/_inference/common_openai/openai_chat_completions_core.py

Lines changed: 2 additions & 1 deletion
@@ -90,9 +90,10 @@ async def make_single_request(
         common_params: dict[str, Any] = {
             "model": self._model,
             "messages": convert_messages(request.messages),
-            "max_completion_tokens": request.max_completion_tokens + profile_configuration.expected_additional_reasoning_tokens,
             "n": 1,
         }
+        if request.max_completion_tokens:
+            common_params.update({"max_completion_tokens": request.max_completion_tokens + profile_configuration.expected_additional_reasoning_tokens})
         if request.temperature:
             common_params.update({"temperature": request.temperature})


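The net effect is that max_completion_tokens becomes optional: when the caller does not set it (as the new semantic parse_pdf path appears to do), the key is simply omitted and the provider's default output cap applies. Below is a minimal standalone sketch of that conditional-parameter pattern, using hypothetical stand-in objects rather than fenic's own request and profile classes.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubRequest:  # hypothetical stand-in, not fenic's FenicCompletionsRequest
    max_completion_tokens: Optional[int]
    temperature: Optional[float]


def build_common_params(request: StubRequest, expected_additional_reasoning_tokens: int) -> dict[str, Any]:
    params: dict[str, Any] = {"model": "gpt-4o-mini", "n": 1}
    if request.max_completion_tokens:
        # Only send a cap when the caller asked for one; pad it with the
        # profile's expected reasoning overhead, as in the diff above.
        params["max_completion_tokens"] = (
            request.max_completion_tokens + expected_additional_reasoning_tokens
        )
    if request.temperature:
        params["temperature"] = request.temperature
    return params


print(build_common_params(StubRequest(None, 0.2), 512))   # {'model': 'gpt-4o-mini', 'n': 1, 'temperature': 0.2}
print(build_common_params(StubRequest(1024, 0.2), 512))   # also includes 'max_completion_tokens': 1536

Note that the truthiness checks also skip explicit zeros (max_completion_tokens=0 or temperature=0.0), which matches the behaviour of the diff above.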
src/fenic/_inference/google/gemini_batch_embeddings_client.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def estimate_tokens_for_request(self, request: FenicEmbeddingsRequest) -> TokenE
             input_tokens=self.token_counter.count_tokens(request.doc), output_tokens=0
         )

-    def _get_max_output_tokens(self, request: FenicEmbeddingsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicEmbeddingsRequest) -> int:
         return 0

     def reset_metrics(self):

src/fenic/_inference/google/gemini_native_chat_completions_client.py

Lines changed: 59 additions & 56 deletions
@@ -132,56 +132,6 @@ def count_tokens(self, messages: Tokenizable) -> int: # type: ignore[override]
         # Re-expose for mypy – same implementation as parent.
         return super().count_tokens(messages)

-    def _estimate_structured_output_overhead(self, response_format: ResolvedResponseFormat) -> int:
-        """Use Google-specific response schema token estimation.
-
-        Args:
-            response_format: Pydantic model class defining the response format
-
-        Returns:
-            Estimated token overhead for structured output
-        """
-        return self._estimate_response_schema_tokens(response_format)
-
-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> Optional[int]:
-        """Get maximum output tokens including thinking budget.
-
-        If max_completion_tokens is not set, return None.
-
-        Conservative estimate that includes both completion tokens and
-        thinking token budget with a safety margin.
-
-        Args:
-            request: The completion request
-
-        Returns:
-            Maximum output tokens (completion + thinking budget with safety margin)
-        """
-        if request.max_completion_tokens is None:
-            return None
-        profile_config = self._profile_manager.get_profile_by_name(
-            request.model_profile
-        )
-        return request.max_completion_tokens + int(
-            1.5 * profile_config.thinking_token_budget
-        )
-
-    @cache  # noqa: B019 – builtin cache OK here.
-    def _estimate_response_schema_tokens(self, response_format: ResolvedResponseFormat) -> int:
-        """Estimate token count for a response format schema.
-
-        Uses Google's tokenizer to count tokens in a JSON schema representation
-        of the response format. Results are cached for performance.
-
-        Args:
-            response_format: Pydantic model class defining the response format
-
-        Returns:
-            Estimated token count for the response format
-        """
-        schema_str = response_format.schema_fingerprint
-        return self._token_counter.count_tokens(schema_str)
-
     def get_request_key(self, request: FenicCompletionsRequest) -> str:
         """Generate a unique key for the request.

@@ -196,19 +146,17 @@ def get_request_key(self, request: FenicCompletionsRequest) -> str:
     def estimate_tokens_for_request(self, request: FenicCompletionsRequest):
         """Estimate the number of tokens for a request.

+        If the request provides a max_completion_tokens value, use that. Otherwise, estimate the output tokens based on the file size.
+
         Args:
             request: The request to estimate tokens for

         Returns:
             TokenEstimate: The estimated token usage
         """
-
-        # Count input tokens
         input_tokens = self.count_tokens(request.messages)
         input_tokens += self._count_auxiliary_input_tokens(request)
-
-        output_tokens = self._get_max_output_tokens(request) or self._model_parameters.max_output_tokens
-
+        output_tokens = self._estimate_output_tokens(request)
         return TokenEstimate(input_tokens=input_tokens, output_tokens=output_tokens)

     async def make_single_request(
@@ -228,7 +176,7 @@ async def make_single_request(
         """

         profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
-        max_output_tokens = self._get_max_output_tokens(request)
+        max_output_tokens = self._get_max_output_token_request_limit(request)

         generation_config: GenerateContentConfigDict = {
             "temperature": request.temperature,
@@ -355,3 +303,58 @@ async def make_single_request(
         finally:
             if file_obj:
                 await delete_file(self._client, file_obj.name)
+
+    @cache  # noqa: B019 – builtin cache OK here.
+    def _estimate_response_schema_tokens(self, response_format: ResolvedResponseFormat) -> int:
+        """Estimate token count for a response format schema.
+
+        Uses Google's tokenizer to count tokens in a JSON schema representation
+        of the response format. Results are cached for performance.
+
+        Args:
+            response_format: Pydantic model class defining the response format
+
+        Returns:
+            Estimated token count for the response format
+        """
+        schema_str = response_format.schema_fingerprint
+        return self._token_counter.count_tokens(schema_str)
+
+    def _estimate_structured_output_overhead(self, response_format: ResolvedResponseFormat) -> int:
+        """Use Google-specific response schema token estimation.
+
+        Args:
+            response_format: Pydantic model class defining the response format
+
+        Returns:
+            Estimated token overhead for structured output
+        """
+        return self._estimate_response_schema_tokens(response_format)
+
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        estimated_output_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            estimated_output_tokens = self.token_counter.count_file_output_tokens(request.messages)
+        return estimated_output_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> Optional[int]:
+        """Get the upper limit of output tokens for a request.
+
+        If max_completion_tokens is not set, don't apply a limit and return None.
+
+        Include the thinking token budget with a safety margin."""
+        max_output_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # Guardrail to ensure the model uses a sane amount of output tokens.
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            max_output_tokens = self.token_counter.count_file_output_tokens(request.messages) * 2
+        return max_output_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_expected_additional_reasoning_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Get the expected additional reasoning tokens for a request. Include a safety margin."""
+        profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
+        return int(
+            1.5 * profile_config.thinking_token_budget
+        )

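In short, the Gemini client now distinguishes an output-token estimate (fed into TokenEstimate for rate-limit accounting) from the cap actually placed on the request: when no max_completion_tokens is given and the request carries a file, the estimate comes from count_file_output_tokens, the request limit is twice that estimate, and both add a 1.5x thinking-budget margin. A self-contained sketch of that arithmetic, with illustrative names and numbers rather than fenic APIs:

from typing import Optional

THINKING_SAFETY_FACTOR = 1.5  # margin applied to the profile's thinking budget
FILE_LIMIT_FACTOR = 2         # guardrail multiplier on the file-based estimate


def estimate_output_tokens(
    max_completion_tokens: Optional[int],
    file_output_tokens: Optional[int],
    thinking_token_budget: int,
) -> int:
    """Value used for scheduling/rate limiting (TokenEstimate)."""
    estimated = max_completion_tokens or 0
    if max_completion_tokens is None and file_output_tokens is not None:
        estimated = file_output_tokens
    return estimated + int(THINKING_SAFETY_FACTOR * thinking_token_budget)


def max_output_token_request_limit(
    max_completion_tokens: Optional[int],
    file_output_tokens: Optional[int],
    thinking_token_budget: int,
) -> int:
    """Cap sent with the request: 2x the file-based estimate when uncapped."""
    limit = max_completion_tokens or 0
    if max_completion_tokens is None and file_output_tokens is not None:
        limit = FILE_LIMIT_FACTOR * file_output_tokens
    return limit + int(THINKING_SAFETY_FACTOR * thinking_token_budget)


# A PDF whose transcription is estimated at 4,000 output tokens,
# with a 1,024-token thinking budget:
print(estimate_output_tokens(None, 4_000, 1_024))           # 5536
print(max_output_token_request_limit(None, 4_000, 1_024))   # 9536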
src/fenic/_inference/language_model.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@

 @dataclass
 class InferenceConfiguration:
-    # If max_output_tokens is not provided, do not include it in the request.
+    # If max_output_tokens is not provided, model_client will add a guardrail based on the estimated output tokens.
     max_output_tokens: Optional[int]
     temperature: float
     top_logprobs: Optional[int] = None

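The comment change captures the new contract for callers such as semantic parse_pdf: leaving max_output_tokens unset no longer means the request goes out uncapped; instead the model client derives a guardrail from its own output estimate. A stripped-down, hypothetical stand-in (not the real dataclass, which has more fields) to illustrate the two modes:

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniInferenceConfiguration:  # illustrative stand-in only
    max_output_tokens: Optional[int]  # None -> client computes a guardrail from its estimate
    temperature: float


explicit_cap = MiniInferenceConfiguration(max_output_tokens=2_048, temperature=0.0)
derived_cap = MiniInferenceConfiguration(max_output_tokens=None, temperature=0.0)  # e.g. parse_pdf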
src/fenic/_inference/model_client.py

Lines changed: 2 additions & 2 deletions
@@ -245,8 +245,8 @@ def _estimate_structured_output_overhead(self, response_format: ResolvedResponse


     @abstractmethod
-    def _get_max_output_tokens(self, request: RequestT) -> int:
-        """Get conservative output token estimate. Override in subclasses for provider-specific logic."""
+    def _get_max_output_token_request_limit(self, request: RequestT) -> int:
+        """Get the upper limit of output tokens to set on a request."""
         pass

     #

src/fenic/_inference/openai/openai_batch_chat_completions_client.py

Lines changed: 22 additions & 7 deletions
@@ -65,6 +65,7 @@ def __init__(
             profile_configurations=profiles,
             default_profile_name=default_profile_name,
         )
+
         self._core = OpenAIChatCompletionsCore(
             model=model,
             model_provider=ModelProvider.OPENAI,
@@ -108,7 +109,7 @@ def estimate_tokens_for_request(self, request: FenicCompletionsRequest) -> Token
         """
         return TokenEstimate(
             input_tokens=self.token_counter.count_tokens(request.messages),
-            output_tokens=self._get_max_output_tokens(request)
+            output_tokens=self._estimate_output_tokens(request)
         )

     def reset_metrics(self):
@@ -123,10 +124,24 @@ def get_metrics(self) -> LMMetrics:
         """
         return self._core.get_metrics()

-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> int:
-        """Conservative estimate: max_completion_tokens + reasoning effort-based thinking tokens."""
-        base_tokens = request.max_completion_tokens
-
-        # Get profile-specific reasoning effort
+    def _estimate_output_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Estimate the number of output tokens for a request."""
+        base_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            base_tokens += self.token_counter.count_file_output_tokens(messages=request.messages)
+        return base_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
+        """Return the maximum output token limit for a request."""
+        max_output_tokens = request.max_completion_tokens or 0
+        if request.max_completion_tokens is None and request.messages.user_file:
+            # Guardrail to ensure the model uses a sane amount of output tokens.
+            # TODO(DY): the semantic operator should dictate how the file affects the token estimate
+            max_output_tokens = self.token_counter.count_file_output_tokens(request.messages) * 2
+        return max_output_tokens + self._get_expected_additional_reasoning_tokens(request)
+
+    def _get_expected_additional_reasoning_tokens(self, request: FenicCompletionsRequest) -> int:
+        """Get the expected additional reasoning tokens for a request."""
         profile_config = self._profile_manager.get_profile_by_name(request.model_profile)
-        return base_tokens + profile_config.expected_additional_reasoning_tokens
+        return profile_config.expected_additional_reasoning_tokens

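For the OpenAI client the split mirrors the Gemini one, except the reasoning margin is the profile's flat expected_additional_reasoning_tokens. A worked example with hypothetical numbers, assuming the tokenizer estimates 4,000 output tokens for the attached PDF, the profile expects 800 additional reasoning tokens, and no max_completion_tokens is set:

file_output_tokens = 4_000                    # hypothetical count_file_output_tokens(...) result
expected_additional_reasoning_tokens = 800    # hypothetical profile value

estimated_output = file_output_tokens + expected_additional_reasoning_tokens   # 4,800 for TokenEstimate
request_limit = file_output_tokens * 2 + expected_additional_reasoning_tokens  # 8,800 sent as the cap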
src/fenic/_inference/openai/openai_batch_embeddings_client.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def get_metrics(self) -> RMMetrics:
         """
         return self._core.get_metrics()

-    def _get_max_output_tokens(self, request: RequestT) -> int:
+    def _get_max_output_token_request_limit(self, request: RequestT) -> int:
         return 0

     async def validate_api_key(self):

src/fenic/_inference/openrouter/openrouter_batch_chat_completions_client.py

Lines changed: 10 additions & 3 deletions
@@ -94,7 +94,7 @@ async def make_single_request(
         common_params = {
             "model": self.model,
             "messages": convert_messages(request.messages),
-            "max_completion_tokens": self._get_max_output_tokens(request),
+            "max_completion_tokens": self._get_max_output_token_request_limit(request),
             "n": 1,
         }

@@ -239,7 +239,7 @@ def estimate_tokens_for_request(
     ) -> TokenEstimate:
         return TokenEstimate(
             input_tokens=self.token_counter.count_tokens(request.messages),
-            output_tokens=self._get_max_output_tokens(request),
+            output_tokens=self.token_counter.count_tokens(request.messages) + self._get_expected_additional_reasoning_tokens(request),
         )

     def reset_metrics(self):
@@ -248,7 +248,14 @@ def reset_metrics(self):
     def get_metrics(self) -> LMMetrics:
         return self._metrics

-    def _get_max_output_tokens(self, request: FenicCompletionsRequest) -> int:
+    def _get_max_output_token_request_limit(self, request: FenicCompletionsRequest) -> int:
+        """Get the upper limit of output tokens for a request.
+
+        If max_completion_tokens is not set, don't apply a limit and return None.
+
+        Include the thinking token budget with a safety margin."""
+        if request.max_completion_tokens is None:
+            return None
         return request.max_completion_tokens + self._get_expected_additional_reasoning_tokens(request)

     # This is a slightly less conservative estimate than the OpenRouter documentation on how reasoning_effort is used to

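The OpenRouter client keeps the older behaviour of applying no cap when max_completion_tokens is unset (the limit helper returns None), and its token estimate now approximates the output size by the prompt's own token count plus the reasoning margin. A small sketch of both helpers under those assumptions, with illustrative tokenizer counts:

from typing import Optional


def openrouter_output_estimate(prompt_tokens: int, expected_additional_reasoning_tokens: int) -> int:
    # Output is approximated by the input size, plus the expected reasoning overhead.
    return prompt_tokens + expected_additional_reasoning_tokens


def openrouter_request_limit(
    max_completion_tokens: Optional[int],
    expected_additional_reasoning_tokens: int,
) -> Optional[int]:
    # No explicit cap requested -> return None and apply no limit.
    if max_completion_tokens is None:
        return None
    return max_completion_tokens + expected_additional_reasoning_tokens


print(openrouter_output_estimate(3_000, 800))   # 3800
print(openrouter_request_limit(None, 800))      # None
print(openrouter_request_limit(2_000, 800))     # 2800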