99from urllib .parse import urljoin
1010
1111import httpx
12+ from oauthlib .oauth2 import OAuth2Error
1213from oauthlib .oauth2 import WebApplicationClient
13- from oauthlib .oauth2 .rfc6749 .errors import CustomOAuth2Error
1414from social_core .backends .oauth import BaseOAuth2
15+ from social_core .exceptions import AuthException
1516from social_core .strategy import BaseStrategy
16- from starlette .exceptions import HTTPException
1717from starlette .requests import Request
1818from starlette .responses import RedirectResponse
1919
2020from .claims import Claims
2121from .client import OAuth2Client
22-
23-
24- class OAuth2LoginError (HTTPException ):
25- """Raised when any login-related error occurs."""
22+ from .exceptions import OAuth2AuthenticationError
23+ from .exceptions import OAuth2InvalidRequestError
2624
2725
2826class OAuth2Strategy (BaseStrategy ):
@@ -56,6 +54,7 @@ class OAuth2Core:
5654 _oauth_client : Optional [WebApplicationClient ] = None
5755 _authorization_endpoint : str = None
5856 _token_endpoint : str = None
57+ _state : str = None
5958
6059 def __init__ (self , client : OAuth2Client ) -> None :
6160 self .client_id = client .client_id
@@ -83,6 +82,8 @@ def authorization_url(self, request: Request) -> str:
8382 oauth2_query_params = dict (state = state , scope = self .scope , redirect_uri = redirect_uri )
8483 oauth2_query_params .update (request .query_params )
8584
85+ self ._state = oauth2_query_params .get ("state" )
86+
8687 return str (self ._oauth_client .prepare_request_uri (
8788 self ._authorization_endpoint ,
8889 ** oauth2_query_params ,
@@ -93,9 +94,11 @@ def authorization_redirect(self, request: Request) -> RedirectResponse:
9394
9495 async def token_data (self , request : Request , ** httpx_client_args ) -> dict :
9596 if not request .query_params .get ("code" ):
96- raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
97+ raise OAuth2InvalidRequestError (400 , "'code' parameter was not found in callback request" )
9798 if not request .query_params .get ("state" ):
98- raise OAuth2LoginError (400 , "'state' parameter was not found in callback request" )
99+ raise OAuth2InvalidRequestError (400 , "'state' parameter was not found in callback request" )
100+ if request .query_params .get ("state" ) != self ._state :
101+ raise OAuth2InvalidRequestError (400 , "'state' parameter does not match" )
99102
100103 redirect_uri = self .get_redirect_uri (request )
101104 scheme = "http" if request .auth .http else "https"
@@ -112,12 +115,14 @@ async def token_data(self, request: Request, **httpx_client_args) -> dict:
112115 headers .update ({"Accept" : "application/json" })
113116 auth = httpx .BasicAuth (self .client_id , self .client_secret )
114117 async with httpx .AsyncClient (auth = auth , ** httpx_client_args ) as session :
115- response = await session .post (token_url , headers = headers , content = content )
116118 try :
119+ response = await session .post (token_url , headers = headers , content = content )
117120 self ._oauth_client .parse_request_body_response (json .dumps (response .json ()))
118121 return self .standardize (self .backend .user_data (self .access_token ))
119- except (CustomOAuth2Error , Exception ) as e :
120- raise OAuth2LoginError (400 , str (e ))
122+ except (OAuth2Error , httpx .HTTPError ) as e :
123+ raise OAuth2InvalidRequestError (400 , str (e ))
124+ except (AuthException , Exception ) as e :
125+ raise OAuth2AuthenticationError (401 , str (e ))
121126
122127 async def token_redirect (self , request : Request , ** kwargs ) -> RedirectResponse :
123128 access_token = request .auth .jwt_create (await self .token_data (request , ** kwargs ))
0 commit comments