diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index a5c4b7346..252dfd9e4 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -4,11 +4,9 @@ Implements authorization code flow with PKCE and automatic token refresh. """ +from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError from mcp.client.auth.oauth2 import ( OAuthClientProvider, - OAuthFlowError, - OAuthRegistrationError, - OAuthTokenError, PKCEParameters, TokenStorage, ) diff --git a/src/mcp/client/auth/exceptions.py b/src/mcp/client/auth/exceptions.py new file mode 100644 index 000000000..5ce8777b8 --- /dev/null +++ b/src/mcp/client/auth/exceptions.py @@ -0,0 +1,10 @@ +class OAuthFlowError(Exception): + """Base exception for OAuth flow errors.""" + + +class OAuthTokenError(OAuthFlowError): + """Raised when token operations fail.""" + + +class OAuthRegistrationError(OAuthFlowError): + """Raised when client registration fails.""" diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 634161b92..1463655ae 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -7,7 +7,6 @@ import base64 import hashlib import logging -import re import secrets import string import time @@ -20,6 +19,21 @@ import httpx from pydantic import BaseModel, Field, ValidationError +from mcp.client.auth import OAuthFlowError, OAuthTokenError +from mcp.client.auth.utils import ( + build_protected_resource_discovery_urls, + create_client_registration_request, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + get_discovery_urls, + handle_auth_metadata_response, + handle_protected_resource_response, + handle_registration_response, + handle_token_response_scopes, +) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, @@ -28,24 +42,15 @@ OAuthToken, ProtectedResourceMetadata, ) -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url -from mcp.types import LATEST_PROTOCOL_VERSION +from mcp.shared.auth_utils import ( + calculate_token_expiry, + check_resource_allowed, + resource_url_from_server_url, +) logger = logging.getLogger(__name__) -class OAuthFlowError(Exception): - """Base exception for OAuth flow errors.""" - - -class OAuthTokenError(OAuthFlowError): - """Raised when token operations fail.""" - - -class OAuthRegistrationError(OAuthFlowError): - """Raised when client registration fails.""" - - class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -114,11 +119,8 @@ def get_authorization_base_url(self, server_url: str) -> str: return f"{parsed.scheme}://{parsed.netloc}" def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time.""" - if token.expires_in: - self.token_expiry_time = time.time() + token.expires_in - else: # pragma: no cover - self.token_expiry_time = None + """Update token expiry time using shared util function.""" + self.token_expiry_time = calculate_token_expiry(token.expires_in) def is_token_valid(self) -> bool: """Check if current token is valid.""" @@ -200,85 +202,6 @@ def __init__( ) self._initialized = False - def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]: - """ - Build ordered list of URLs to try for protected resource metadata discovery. - - Per SEP-985, the client MUST: - 1. Try resource_metadata from WWW-Authenticate header (if present) - 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} - 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource - - Args: - init_response: The initial 401 response from the server - - Returns: - Ordered list of URLs to try for discovery - """ - urls: list[str] = [] - - # Priority 1: WWW-Authenticate header with resource_metadata parameter - www_auth_url = self._extract_resource_metadata_from_www_auth(init_response) - if www_auth_url: - urls.append(www_auth_url) - - # Priority 2-3: Well-known URIs (RFC 9728) - parsed = urlparse(self.context.server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # Priority 2: Path-based well-known URI (if server has a path component) - if parsed.path and parsed.path != "/": - path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") - urls.append(path_based_url) - - # Priority 3: Root-based well-known URI - root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - urls.append(root_based_url) - - return urls - - def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None: - """ - Extract field from WWW-Authenticate header. - - Returns: - Field value if found in WWW-Authenticate header, None otherwise - """ - www_auth_header = init_response.headers.get("WWW-Authenticate") - if not www_auth_header: - return None - - # Pattern matches: field_name="value" or field_name=value (unquoted) - pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) - - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) - - return None - - def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ - if not init_response or init_response.status_code != 401: # pragma: no cover - return None - - return self._extract_field_from_www_auth(init_response, "resource_metadata") - - def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract scope parameter from WWW-Authenticate header as per RFC6750. - - Returns: - Scope string if found in WWW-Authenticate header, None otherwise - """ - return self._extract_field_from_www_auth(init_response, "scope") - async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """ Handle protected resource metadata discovery response. @@ -301,61 +224,15 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> # Invalid metadata - try next URL logger.warning(f"Invalid protected resource metadata at {response.request.url}") return False - elif response.status_code == 404: + elif response.status_code == 404: # pragma: no cover # Not found - try next URL in fallback chain logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") return False - else: # pragma: no cover - # Other error - fail immediately - raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}") - - def _select_scopes(self, init_response: httpx.Response) -> None: - """Select scopes as outlined in the 'Scope Selection Strategy in the MCP spec.""" - # Per MCP spec, scope selection priority order: - # 1. Use scope from WWW-Authenticate header (if provided) - # 2. Use all scopes from PRM scopes_supported (if available) - # 3. Omit scope parameter if neither is available - # - www_authenticate_scope = self._extract_scope_from_www_auth(init_response) - if www_authenticate_scope is not None: - # Priority 1: WWW-Authenticate header scope - self.context.client_metadata.scope = www_authenticate_scope - elif ( - self.context.protected_resource_metadata is not None - and self.context.protected_resource_metadata.scopes_supported is not None - ): - # Priority 2: PRM scopes_supported - self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported) else: - # Priority 3: Omit scope parameter - self.context.client_metadata.scope = None - - def _get_discovery_urls(self) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts.""" - urls: list[str] = [] - auth_server_url = self.context.auth_server_url or self.context.server_url - parsed = urlparse(auth_server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # RFC 8414: Path-aware OAuth discovery - if parsed.path and parsed.path != "/": - oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oauth_path)) - - # OAuth root fallback - urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 - if parsed.path and parsed.path != "/": - oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oidc_path)) - - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" - urls.append(oidc_fallback) - - return urls + # Other error - fail immediately + raise OAuthFlowError( + f"Protected Resource Metadata request failed: {response.status_code}" + ) # pragma: no cover async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" @@ -363,7 +240,7 @@ async def _register_client(self) -> httpx.Request | None: return None if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) + registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) registration_url = urljoin(auth_base_url, "/register") @@ -374,20 +251,6 @@ async def _register_client(self) -> httpx.Request | None: "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} ) - async def _handle_registration_response(self, response: httpx.Response) -> None: - """Handle registration response.""" - if response.status_code not in (200, 201): - await response.aread() - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - - try: - content = await response.aread() - client_info = OAuthClientInformationFull.model_validate_json(content) - self.context.client_info = client_info - await self.context.storage.set_client_info(client_info) - except ValidationError as e: # pragma: no cover - raise OAuthRegistrationError(f"Invalid registration response: {e}") - async def _perform_authorization(self) -> httpx.Request: """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() @@ -427,9 +290,9 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover - if self.context.client_metadata.scope: # pragma: no cover + if self.context.client_metadata.scope: # pragma: no branch auth_params["scope"] = self.context.client_metadata.scope authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -489,30 +352,18 @@ async def _exchange_token_authorization_code( async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" - if response.status_code != 200: # pragma: no cover - body = await response.aread() - body = body.decode("utf-8") - raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}") - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) + if response.status_code != 200: + body = await response.aread() # pragma: no cover + body_text = body.decode("utf-8") # pragma: no cover + raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover - # Validate scopes - if token_response.scope and self.context.client_metadata.scope: - requested_scopes = set(self.context.client_metadata.scope.split()) - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes - if unauthorized_scopes: - raise OAuthTokenError( - f"Server granted unauthorized scopes: {unauthorized_scopes}" - ) # pragma: no cover + # Parse and validate response with scope validation + token_response = await handle_token_response_scopes(response) - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - except ValidationError as e: # pragma: no cover - raise OAuthTokenError(f"Invalid token response: {e}") + # Store tokens in context + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) async def _refresh_token(self) -> httpx.Request: """Build token refresh request.""" @@ -577,9 +428,6 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: content = await response.aread() metadata = OAuthMetadata.model_validate_json(content) @@ -594,12 +442,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no cover + if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request + refresh_request = await self._refresh_token() # pragma: no cover + refresh_response = yield refresh_request # pragma: no cover - if not await self._handle_refresh_response(refresh_response): + if not await self._handle_refresh_response(refresh_response): # pragma: no cover # Refresh failed, need full re-authentication self._initialized = False @@ -612,46 +460,68 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Perform full OAuth flow try: # OAuth flow must be inline due to generator constraints + www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - discovery_urls = self._build_protected_resource_discovery_urls(response) - discovery_success = False - for url in discovery_urls: # pragma: no cover - discovery_request = httpx.Request( - "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - ) - discovery_response = yield discovery_request - discovery_success = await self._handle_protected_resource_response(discovery_response) - if discovery_success: - break + prm_discovery_urls = build_protected_resource_discovery_urls( + www_auth_resource_metadata_url, self.context.server_url + ) + prm_discovery_success = False + for url in prm_discovery_urls: # pragma: no branch + discovery_request = create_oauth_metadata_request(url) - if not discovery_success: + discovery_response = yield discovery_request # sending request + + prm = await handle_protected_resource_response(discovery_response) + if prm: + prm_discovery_success = True + + # saving the response metadata + self.context.protected_resource_metadata = prm + if prm.authorization_servers: # pragma: no branch + self.context.auth_server_url = str(prm.authorization_servers[0]) + + break + else: + logger.debug(f"Protected resource metadata discovery failed: {url}") + if not prm_discovery_success: raise OAuthFlowError( "Protected resource metadata discovery failed: no valid metadata found" ) # pragma: no cover - # Step 2: Apply scope selection strategy - self._select_scopes(response) - - # Step 3: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls() - for url in discovery_urls: # pragma: no branch - oauth_metadata_request = self._create_oauth_metadata_request(url) + # Step 2: Discover OAuth metadata (with fallback for legacy servers) + asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url) + for url in asm_discovery_urls: # pragma: no cover + oauth_metadata_request = create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request - if oauth_metadata_response.status_code == 200: - try: - await self._handle_oauth_metadata_response(oauth_metadata_response) - break - except ValidationError: # pragma: no cover - continue - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: - break # Non-4XX error, stop trying + ok, asm = await handle_auth_metadata_response(oauth_metadata_response) + if not ok: + break + if ok and asm: + self.context.oauth_metadata = asm + break + else: + logger.debug(f"OAuth metadata discovery failed: {url}") + + # Step 3: Apply scope selection strategy + self.context.client_metadata.scope = get_client_metadata_scopes( + www_auth_resource_metadata_url, + self.context.protected_resource_metadata, + self.context.oauth_metadata, + ) # Step 4: Register client if needed - registration_request = await self._register_client() - if registration_request: + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) + if not self.context.client_info: registration_response = yield registration_request - await self._handle_registration_response(registration_response) + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() @@ -665,13 +535,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. yield request elif response.status_code == 403: # Step 1: Extract error field from WWW-Authenticate header - error = self._extract_field_from_www_auth(response, "error") + error = extract_field_from_www_auth(response, "error") # Step 2: Check if we need to step-up authorization if error == "insufficient_scope": # pragma: no branch try: # Step 2a: Update the required scopes - self._select_scopes(response) + self.context.client_metadata.scope = get_client_metadata_scopes( + extract_scope_from_www_auth(response), self.context.protected_resource_metadata + ) # Step 2b: Perform (re-)authorization and token exchange token_response = yield await self._perform_authorization() diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py new file mode 100644 index 000000000..1774c5ff5 --- /dev/null +++ b/src/mcp/client/auth/utils.py @@ -0,0 +1,250 @@ +import logging +import re +from urllib.parse import urljoin, urlparse + +from httpx import Request, Response +from pydantic import ValidationError + +from mcp.client.auth import OAuthRegistrationError, OAuthTokenError +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) +from mcp.types import LATEST_PROTOCOL_VERSION + +logger = logging.getLogger(__name__) + + +def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: + """ + Extract field from WWW-Authenticate header. + + Returns: + Field value if found in WWW-Authenticate header, None otherwise + """ + www_auth_header = response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: field_name="value" or field_name=value (unquoted) + pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + +def extract_scope_from_www_auth(response: Response) -> str | None: + """ + Extract scope parameter from WWW-Authenticate header as per RFC6750. + + Returns: + Scope string if found in WWW-Authenticate header, None otherwise + """ + return extract_field_from_www_auth(response, "scope") + + +def extract_resource_metadata_from_www_auth(response: Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not response or response.status_code != 401: + return None # pragma: no cover + + return extract_field_from_www_auth(response, "resource_metadata") + + +def build_protected_resource_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]: + """ + Build ordered list of URLs to try for protected resource metadata discovery. + + Per SEP-985, the client MUST: + 1. Try resource_metadata from WWW-Authenticate header (if present) + 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} + 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource + + Args: + www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header + server_url: server url + + Returns: + Ordered list of URLs to try for discovery + """ + urls: list[str] = [] + + # Priority 1: WWW-Authenticate header with resource_metadata parameter + if www_auth_url: + urls.append(www_auth_url) + + # Priority 2-3: Well-known URIs (RFC 9728) + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # Priority 2: Path-based well-known URI (if server has a path component) + if parsed.path and parsed.path != "/": + path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") + urls.append(path_based_url) + + # Priority 3: Root-based well-known URI + root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + urls.append(root_based_url) + + return urls + + +def get_client_metadata_scopes( + www_authenticate_scope: str | None, + protected_resource_metadata: ProtectedResourceMetadata | None, + authorization_server_metadata: OAuthMetadata | None = None, +) -> str | None: + """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" + # Per MCP spec, scope selection priority order: + # 1. Use scope from WWW-Authenticate header (if provided) + # 2. Use all scopes from PRM scopes_supported (if available) + # 3. Omit scope parameter if neither is available + + if www_authenticate_scope is not None: + # Priority 1: WWW-Authenticate header scope + return www_authenticate_scope + elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: + # Priority 2: PRM scopes_supported + return " ".join(protected_resource_metadata.scopes_supported) + elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: + return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover + else: + # Priority 3: Omit scope parameter + return None + + +def get_discovery_urls(auth_server_url: str) -> list[str]: + """Generate ordered list of (url, type) tuples for discovery attempts.""" + urls: list[str] = [] + parsed = urlparse(auth_server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) + + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) + + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" + urls.append(oidc_fallback) + + return urls + + +async def handle_protected_resource_response( + response: Response, +) -> ProtectedResourceMetadata | None: + """ + Handle protected resource metadata discovery response. + + Per SEP-985, supports fallback when discovery fails at one URL. + + Returns: + True if metadata was successfully discovered, False if we should try next URL + """ + if response.status_code == 200: + try: + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + return metadata + + except ValidationError: # pragma: no cover + # Invalid metadata - try next URL + return None + else: + # Not found - try next URL in fallback chain + return None + + +async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]: + if response.status_code == 200: + try: + content = await response.aread() + asm = OAuthMetadata.model_validate_json(content) + return True, asm + except ValidationError: # pragma: no cover + return True, None + elif response.status_code < 400 or response.status_code >= 500: + return False, None # Non-4XX error, stop trying + return True, None + + +def create_oauth_metadata_request(url: str) -> Request: + return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + +def create_client_registration_request( + auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str +) -> Request: + """Build registration request or skip if already registered.""" + + if auth_server_metadata and auth_server_metadata.registration_endpoint: + registration_url = str(auth_server_metadata.registration_endpoint) + else: + registration_url = urljoin(auth_base_url, "/register") + + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}) + + +async def handle_registration_response(response: Response) -> OAuthClientInformationFull: + """Handle registration response.""" + if response.status_code not in (200, 201): + await response.aread() + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") + + try: + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + return client_info + # self.context.client_info = client_info + # await self.context.storage.set_client_info(client_info) + except ValidationError as e: # pragma: no cover + raise OAuthRegistrationError(f"Invalid registration response: {e}") + + +async def handle_token_response_scopes( + response: Response, +) -> OAuthToken: + """Parse and validate token response with optional scope validation. + + Parses token response JSON. Callers should check response.status_code before calling. + + Args: + response: HTTP response from token endpoint (status already checked by caller) + + Returns: + Validated OAuthToken model + + Raises: + OAuthTokenError: If response JSON is invalid + """ + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + return token_response + except ValidationError as e: # pragma: no cover + raise OAuthTokenError(f"Invalid token response: {e}") diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 6d6300c9c..8f3c542f2 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -1,5 +1,6 @@ -"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" +"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" +import time from urllib.parse import urlparse, urlsplit, urlunsplit from pydantic import AnyUrl, HttpUrl @@ -67,3 +68,18 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> configured_path += "/" return requested_path.startswith(configured_path) + + +def calculate_token_expiry(expires_in: int | str | None) -> float | None: + """Calculate token expiry timestamp from expires_in seconds. + + Args: + expires_in: Seconds until token expiration (may be string from some servers) + + Returns: + Unix timestamp when token expires, or None if no expiry specified + """ + if expires_in is None: + return None # pragma: no cover + # Defensive: handle servers that return expires_in as string + return time.time() + int(expires_in) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 3feedf9e9..46a552e58 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,6 +11,16 @@ from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth.utils import ( + build_protected_resource_discovery_urls, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + get_discovery_urls, + handle_registration_response, +) from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata @@ -265,7 +275,9 @@ async def callback_handler() -> tuple[str, str | None]: status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") ) - urls = provider._build_protected_resource_discovery_urls(init_response) + urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) assert len(urls) == 1 assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource" @@ -274,7 +286,9 @@ async def callback_handler() -> tuple[str, str | None]: 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' ) - urls = provider._build_protected_resource_discovery_urls(init_response) + urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) assert len(urls) == 2 assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path" assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @@ -282,7 +296,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): """Test OAuth metadata discovery request building.""" - request = oauth_provider._create_oauth_metadata_request("https://example.com") + request = create_oauth_metadata_request("https://example.com") # Ensure correct method and headers, and that the URL is unmodified assert request.method == "GET" @@ -296,7 +310,7 @@ class TestOAuthFallback: @pytest.mark.anyio async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): """Test fallback URL construction order.""" - discovery_urls = oauth_provider._get_discovery_urls() + discovery_urls = get_discovery_urls(oauth_provider.context.auth_server_url or oauth_provider.context.server_url) assert discovery_urls == [ "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp", @@ -450,10 +464,13 @@ async def test_prioritize_www_auth_scope_over_prm( await oauth_provider._handle_protected_resource_response(prm_metadata_response) # Process the scope selection with WWW-Authenticate header - oauth_provider._select_scopes(init_response_with_www_auth_scope) + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_with_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that WWW-Authenticate scope is used (not PRM scopes) - assert oauth_provider.context.client_metadata.scope == "special:scope from:www-authenticate" + assert scopes == "special:scope from:www-authenticate" @pytest.mark.anyio async def test_prioritize_prm_scopes_when_no_www_auth_scope( @@ -467,10 +484,13 @@ async def test_prioritize_prm_scopes_when_no_www_auth_scope( await oauth_provider._handle_protected_resource_response(prm_metadata_response) # Process the scope selection without WWW-Authenticate scope - oauth_provider._select_scopes(init_response_without_www_auth_scope) + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that PRM scopes are used - assert oauth_provider.context.client_metadata.scope == "resource:read resource:write" + assert scopes == "resource:read resource:write" @pytest.mark.anyio async def test_omit_scope_when_no_prm_scopes_or_www_auth( @@ -484,10 +504,12 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) # Process the scope selection without WWW-Authenticate scope - oauth_provider._select_scopes(init_response_without_www_auth_scope) - + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that scope is omitted - assert oauth_provider.context.client_metadata.scope is None + assert scopes is None @pytest.mark.anyio async def test_register_client_request(self, oauth_provider: OAuthClientProvider): @@ -647,7 +669,7 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self, oauth_provider: OAuthClientProvider): + async def test_handle_registration_response_reads_before_accessing_text(self): """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called @@ -663,15 +685,15 @@ async def aread(self): @property def text(self): - if not self._aread_called: # pragma: no cover - raise RuntimeError("Response.text accessed before response.aread()") + if not self._aread_called: + raise RuntimeError("Response.text accessed before response.aread()") # pragma: no cover return self._text mock_response = MockResponse() # This should call aread() before accessing text with pytest.raises(Exception) as exc_info: - await oauth_provider._handle_registration_response(mock_response) + await handle_registration_response(mock_response) # Verify aread() was called assert mock_response._aread_called @@ -846,14 +868,14 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # In the buggy version, this would yield the request AGAIN unconditionally # In the fixed version, this should end the generator try: - await auth_flow.asend(response) # extra request # pragma: no cover + await auth_flow.asend(response) # extra request request_yields += 1 # pragma: no cover - # If we reach here, the bug is present # pragma: no cover - pytest.fail( # pragma: no cover + # If we reach here, the bug is present + pytest.fail( f"Unnecessary retry detected! Request was yielded {request_yields} times. " f"This indicates the retry logic bug that caused 2x performance degradation. " f"The request should only be yielded once for successful responses." - ) + ) # pragma: no cover except StopAsyncIteration: # This is the expected behavior - no unnecessary retry pass @@ -1062,7 +1084,9 @@ async def callback_handler() -> tuple[str, str | None]: ) # Build discovery URLs - discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + discovery_urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) # Should have path-based URL first, then root-based URL assert len(discovery_urls) == 2 @@ -1167,7 +1191,7 @@ async def callback_handler() -> tuple[str, str | None]: final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) - except StopAsyncIteration: + except StopAsyncIteration: # pragma: no cover pass @pytest.mark.anyio @@ -1200,7 +1224,9 @@ async def callback_handler() -> tuple[str, str | None]: ) # Build discovery URLs - discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + discovery_urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) # Should have WWW-Authenticate URL first, then fallback URLs assert len(discovery_urls) == 3 @@ -1268,27 +1294,13 @@ def test_extract_field_from_www_auth_valid_cases( ): """Test extraction of various fields from valid WWW-Authenticate headers.""" - async def redirect_handler(url: str) -> None: - pass # pragma: no cover - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - init_response = httpx.Response( status_code=401, headers={"WWW-Authenticate": www_auth_header}, request=httpx.Request("GET", "https://api.example.com/test"), ) - result = provider._extract_field_from_www_auth(init_response, field_name) + result = extract_field_from_www_auth(init_response, field_name) assert result == expected_value @pytest.mark.parametrize( @@ -1316,24 +1328,10 @@ def test_extract_field_from_www_auth_invalid_cases( ): """Test extraction returns None for invalid cases.""" - async def redirect_handler(url: str) -> None: - pass # pragma: no cover - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} init_response = httpx.Response( status_code=401, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") ) - result = provider._extract_field_from_www_auth(init_response, field_name) + result = extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}"