Skip to content

Commit fd612f2

Browse files
committed
coverage: bring back coverage after rebase
1 parent b9373fe commit fd612f2

File tree

4 files changed

+66
-62
lines changed

4 files changed

+66
-62
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -212,29 +212,31 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
212212
content = await response.aread()
213213
metadata = ProtectedResourceMetadata.model_validate_json(content)
214214
self.context.protected_resource_metadata = metadata
215-
if metadata.authorization_servers:
215+
if metadata.authorization_servers: # pragma: no branch
216216
self.context.auth_server_url = str(metadata.authorization_servers[0])
217217
return True
218218

219-
except ValidationError:
219+
except ValidationError: # pragma: no cover
220220
# Invalid metadata - try next URL
221221
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
222222
return False
223-
elif response.status_code == 404:
223+
elif response.status_code == 404: # pragma: no cover
224224
# Not found - try next URL in fallback chain
225225
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
226226
return False
227227
else:
228228
# Other error - fail immediately
229-
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")
229+
raise OAuthFlowError(
230+
f"Protected Resource Metadata request failed: {response.status_code}"
231+
) # pragma: no cover
230232

231233
async def _register_client(self) -> httpx.Request | None:
232234
"""Build registration request or skip if already registered."""
233235
if self.context.client_info:
234236
return None
235237

236238
if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
237-
registration_url = str(self.context.oauth_metadata.registration_endpoint)
239+
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
238240
else:
239241
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
240242
registration_url = urljoin(auth_base_url, "/register")
@@ -254,20 +256,20 @@ async def _perform_authorization(self) -> httpx.Request:
254256
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
255257
"""Perform the authorization redirect and get auth code."""
256258
if self.context.client_metadata.redirect_uris is None:
257-
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
259+
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
258260
if not self.context.redirect_handler:
259-
raise OAuthFlowError("No redirect handler provided for authorization code grant")
261+
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
260262
if not self.context.callback_handler:
261-
raise OAuthFlowError("No callback handler provided for authorization code grant")
263+
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover
262264

263265
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
264-
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)
266+
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
265267
else:
266268
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
267269
auth_endpoint = urljoin(auth_base_url, "/authorize")
268270

269271
if not self.context.client_info:
270-
raise OAuthFlowError("No client info available for authorization")
272+
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
271273

272274
# Generate PKCE parameters
273275
pkce_params = PKCEParameters.generate()
@@ -284,9 +286,9 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
284286

285287
# Only include resource param if conditions are met
286288
if self.context.should_include_resource_param(self.context.protocol_version):
287-
auth_params["resource"] = self.context.get_resource_url() # RFC 8707
289+
auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover
288290

289-
if self.context.client_metadata.scope:
291+
if self.context.client_metadata.scope: # pragma: no branch
290292
auth_params["scope"] = self.context.client_metadata.scope
291293

292294
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
@@ -296,10 +298,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
296298
auth_code, returned_state = await self.context.callback_handler()
297299

298300
if returned_state is None or not secrets.compare_digest(returned_state, state):
299-
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}")
301+
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover
300302

301303
if not auth_code:
302-
raise OAuthFlowError("No authorization code received")
304+
raise OAuthFlowError("No authorization code received") # pragma: no cover
303305

304306
# Return auth code and code verifier for token exchange
305307
return auth_code, pkce_params.code_verifier
@@ -317,9 +319,9 @@ async def _exchange_token_authorization_code(
317319
) -> httpx.Request:
318320
"""Build token exchange request for authorization_code flow."""
319321
if self.context.client_metadata.redirect_uris is None:
320-
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
322+
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
321323
if not self.context.client_info:
322-
raise OAuthFlowError("Missing client info")
324+
raise OAuthFlowError("Missing client info") # pragma: no cover
323325

324326
token_url = self._get_token_endpoint()
325327
token_data = token_data or {}
@@ -347,9 +349,9 @@ async def _exchange_token_authorization_code(
347349
async def _handle_token_response(self, response: httpx.Response) -> None:
348350
"""Handle token exchange response."""
349351
if response.status_code != 200:
350-
body = await response.aread()
351-
body_text = body.decode("utf-8")
352-
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}")
352+
body = await response.aread() # pragma: no cover
353+
body_text = body.decode("utf-8") # pragma: no cover
354+
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
353355

354356
# Parse and validate response with scope validation
355357
token_response = await handle_token_response_scopes(response)
@@ -362,13 +364,13 @@ async def _handle_token_response(self, response: httpx.Response) -> None:
362364
async def _refresh_token(self) -> httpx.Request:
363365
"""Build token refresh request."""
364366
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
365-
raise OAuthTokenError("No refresh token available")
367+
raise OAuthTokenError("No refresh token available") # pragma: no cover
366368

367369
if not self.context.client_info:
368-
raise OAuthTokenError("No client info available")
370+
raise OAuthTokenError("No client info available") # pragma: no cover
369371

370372
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
371-
token_url = str(self.context.oauth_metadata.token_endpoint)
373+
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
372374
else:
373375
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
374376
token_url = urljoin(auth_base_url, "/token")
@@ -383,14 +385,14 @@ async def _refresh_token(self) -> httpx.Request:
383385
if self.context.should_include_resource_param(self.context.protocol_version):
384386
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
385387

386-
if self.context.client_info.client_secret:
388+
if self.context.client_info.client_secret: # pragma: no branch
387389
refresh_data["client_secret"] = self.context.client_info.client_secret
388390

389391
return httpx.Request(
390392
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
391393
)
392394

393-
async def _handle_refresh_response(self, response: httpx.Response) -> bool:
395+
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
394396
"""Handle token refresh response. Returns True if successful."""
395397
if response.status_code != 200:
396398
logger.warning(f"Token refresh failed: {response.status_code}")
@@ -411,15 +413,15 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool:
411413
self.context.clear_tokens()
412414
return False
413415

414-
async def _initialize(self) -> None:
416+
async def _initialize(self) -> None: # pragma: no cover
415417
"""Load stored tokens and client info."""
416418
self.context.current_tokens = await self.context.storage.get_tokens()
417419
self.context.client_info = await self.context.storage.get_client_info()
418420
self._initialized = True
419421

420422
def _add_auth_header(self, request: httpx.Request) -> None:
421423
"""Add authorization header to request if we have valid tokens."""
422-
if self.context.current_tokens and self.context.current_tokens.access_token:
424+
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
423425
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
424426

425427
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
@@ -431,17 +433,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
431433
"""HTTPX auth flow integration."""
432434
async with self.context.lock:
433435
if not self._initialized:
434-
await self._initialize()
436+
await self._initialize() # pragma: no cover
435437

436438
# Capture protocol version from request headers
437439
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
438440

439441
if not self.context.is_token_valid() and self.context.can_refresh_token():
440442
# Try to refresh token
441-
refresh_request = await self._refresh_token()
442-
refresh_response = yield refresh_request
443+
refresh_request = await self._refresh_token() # pragma: no cover
444+
refresh_response = yield refresh_request # pragma: no cover
443445

444-
if not await self._handle_refresh_response(refresh_response):
446+
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
445447
# Refresh failed, need full re-authentication
446448
self._initialized = False
447449

@@ -461,7 +463,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
461463
www_auth_resource_metadata_url, self.context.server_url
462464
)
463465
prm_discovery_success = False
464-
for url in prm_discovery_urls:
466+
for url in prm_discovery_urls: # pragma: no branch
465467
discovery_request = create_oauth_metadata_request(url)
466468

467469
discovery_response = yield discovery_request # sending request
@@ -472,18 +474,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
472474

473475
# saving the response metadata
474476
self.context.protected_resource_metadata = prm
475-
if prm.authorization_servers:
477+
if prm.authorization_servers: # pragma: no branch
476478
self.context.auth_server_url = str(prm.authorization_servers[0])
477479

478480
break
479481
else:
480482
logger.debug(f"Protected resource metadata discovery failed: {url}")
481483
if not prm_discovery_success:
482-
raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")
484+
raise OAuthFlowError(
485+
"Protected resource metadata discovery failed: no valid metadata found"
486+
) # pragma: no cover
483487

484488
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
485489
asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url)
486-
for url in asm_discovery_urls:
490+
for url in asm_discovery_urls: # pragma: no cover
487491
oauth_metadata_request = create_oauth_metadata_request(url)
488492
oauth_metadata_response = yield oauth_metadata_request
489493

@@ -518,7 +522,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
518522
# Step 5: Perform authorization and complete token exchange
519523
token_response = yield await self._perform_authorization()
520524
await self._handle_token_response(token_response)
521-
except Exception:
525+
except Exception: # pragma: no cover
522526
logger.exception("OAuth flow error")
523527
raise
524528

@@ -530,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
530534
error = extract_field_from_www_auth(response, "error")
531535

532536
# Step 2: Check if we need to step-up authorization
533-
if error == "insufficient_scope":
537+
if error == "insufficient_scope": # pragma: no branch
534538
try:
535539
# Step 2a: Update the required scopes
536540
self.context.client_metadata.scope = get_client_metadata_scopes(
@@ -540,7 +544,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
540544
# Step 2b: Perform (re-)authorization and token exchange
541545
token_response = yield await self._perform_authorization()
542546
await self._handle_token_response(token_response)
543-
except Exception:
547+
except Exception: # pragma: no cover
544548
logger.exception("OAuth flow error")
545549
raise
546550

src/mcp/client/auth/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
5959
Resource metadata URL if found in WWW-Authenticate header, None otherwise
6060
"""
6161
if not response or response.status_code != 401:
62-
return None
62+
return None # pragma: no cover
6363

6464
return extract_field_from_www_auth(response, "resource_metadata")
6565

@@ -120,7 +120,7 @@ def get_client_metadata_scopes(
120120
# Priority 2: PRM scopes_supported
121121
return " ".join(protected_resource_metadata.scopes_supported)
122122
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
123-
return " ".join(authorization_server_metadata.scopes_supported)
123+
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
124124
else:
125125
# Priority 3: Omit scope parameter
126126
return None
@@ -170,7 +170,7 @@ async def handle_protected_resource_response(
170170
metadata = ProtectedResourceMetadata.model_validate_json(content)
171171
return metadata
172172

173-
except ValidationError:
173+
except ValidationError: # pragma: no cover
174174
# Invalid metadata - try next URL
175175
return None
176176
else:
@@ -184,7 +184,7 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth
184184
content = await response.aread()
185185
asm = OAuthMetadata.model_validate_json(content)
186186
return True, asm
187-
except ValidationError:
187+
except ValidationError: # pragma: no cover
188188
return True, None
189189
elif response.status_code < 400 or response.status_code >= 500:
190190
return False, None # Non-4XX error, stop trying
@@ -222,7 +222,7 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma
222222
return client_info
223223
# self.context.client_info = client_info
224224
# await self.context.storage.set_client_info(client_info)
225-
except ValidationError as e:
225+
except ValidationError as e: # pragma: no cover
226226
raise OAuthRegistrationError(f"Invalid registration response: {e}")
227227

228228

@@ -246,5 +246,5 @@ async def handle_token_response_scopes(
246246
content = await response.aread()
247247
token_response = OAuthToken.model_validate_json(content)
248248
return token_response
249-
except ValidationError as e:
249+
except ValidationError as e: # pragma: no cover
250250
raise OAuthTokenError(f"Invalid token response: {e}")

src/mcp/shared/auth_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def generate_pkce_parameters(verifier_length: int = 128) -> tuple[str, str]:
9090
ValueError: If verifier_length is not between 43 and 128
9191
"""
9292
if not 43 <= verifier_length <= 128:
93-
raise ValueError("verifier_length must be between 43 and 128 per RFC 7636")
93+
raise ValueError("verifier_length must be between 43 and 128 per RFC 7636") # pragma: no cover
9494

9595
# Generate code_verifier using unreserved characters per RFC 7636 Section 4.1
9696
# unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
@@ -116,6 +116,6 @@ def calculate_token_expiry(expires_in: int | str | None) -> float | None:
116116
Unix timestamp when token expires, or None if no expiry specified
117117
"""
118118
if expires_in is None:
119-
return None
119+
return None # pragma: no cover
120120
# Defensive: handle servers that return expires_in as string
121121
return time.time() + int(expires_in)

0 commit comments

Comments
 (0)