Decoupled code from DefaultSigningKey (#16743)
Decoupled code from `DefaultSigningKey`. Makes testing a little bit easier and is cleaner.
This commit is contained in:
		
							parent
							
								
									cd8db3a83d
								
							
						
					
					
						commit
						88abb0dc8a
					
				|  | @ -115,7 +115,7 @@ type AccessTokenResponse struct { | |||
| 	IDToken      string    `json:"id_token,omitempty"` | ||||
| } | ||||
| 
 | ||||
| func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { | ||||
| func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { | ||||
| 	if setting.OAuth2.InvalidateRefreshTokens { | ||||
| 		if err := grant.IncreaseCounter(); err != nil { | ||||
| 			return nil, &AccessTokenError{ | ||||
|  | @ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign | |||
| 			ExpiresAt: expirationDate.AsTime().Unix(), | ||||
| 		}, | ||||
| 	} | ||||
| 	signedAccessToken, err := accessToken.SignToken() | ||||
| 	signedAccessToken, err := accessToken.SignToken(serverKey) | ||||
| 	if err != nil { | ||||
| 		return nil, &AccessTokenError{ | ||||
| 			ErrorCode:        AccessTokenErrorCodeInvalidRequest, | ||||
|  | @ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign | |||
| 			ExpiresAt: refreshExpirationDate, | ||||
| 		}, | ||||
| 	} | ||||
| 	signedRefreshToken, err := refreshToken.SignToken() | ||||
| 	signedRefreshToken, err := refreshToken.SignToken(serverKey) | ||||
| 	if err != nil { | ||||
| 		return nil, &AccessTokenError{ | ||||
| 			ErrorCode:        AccessTokenErrorCodeInvalidRequest, | ||||
|  | @ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign | |||
| 			idToken.EmailVerified = user.IsActive | ||||
| 		} | ||||
| 
 | ||||
| 		signedIDToken, err = idToken.SignToken(signingKey) | ||||
| 		signedIDToken, err = idToken.SignToken(clientKey) | ||||
| 		if err != nil { | ||||
| 			return nil, &AccessTokenError{ | ||||
| 				ErrorCode:        AccessTokenErrorCodeInvalidRequest, | ||||
|  | @ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) { | |||
| 	} | ||||
| 
 | ||||
| 	form := web.GetForm(ctx).(*forms.IntrospectTokenForm) | ||||
| 	token, err := oauth2.ParseToken(form.Token) | ||||
| 	token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey) | ||||
| 	if err == nil { | ||||
| 		if token.Valid() == nil { | ||||
| 			grant, err := models.GetOAuth2GrantByID(token.GrantID) | ||||
|  | @ -544,9 +544,11 @@ func AccessTokenOAuth(ctx *context.Context) { | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	signingKey := oauth2.DefaultSigningKey | ||||
| 	if signingKey.IsSymmetric() { | ||||
| 		clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret)) | ||||
| 	serverKey := oauth2.DefaultSigningKey | ||||
| 	clientKey := serverKey | ||||
| 	if serverKey.IsSymmetric() { | ||||
| 		var err error | ||||
| 		clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret)) | ||||
| 		if err != nil { | ||||
| 			handleAccessTokenError(ctx, AccessTokenError{ | ||||
| 				ErrorCode:        AccessTokenErrorCodeInvalidRequest, | ||||
|  | @ -554,14 +556,13 @@ func AccessTokenOAuth(ctx *context.Context) { | |||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
| 		signingKey = clientKey | ||||
| 	} | ||||
| 
 | ||||
| 	switch form.GrantType { | ||||
| 	case "refresh_token": | ||||
| 		handleRefreshToken(ctx, form, signingKey) | ||||
| 		handleRefreshToken(ctx, form, serverKey, clientKey) | ||||
| 	case "authorization_code": | ||||
| 		handleAuthorizationCode(ctx, form, signingKey) | ||||
| 		handleAuthorizationCode(ctx, form, serverKey, clientKey) | ||||
| 	default: | ||||
| 		handleAccessTokenError(ctx, AccessTokenError{ | ||||
| 			ErrorCode:        AccessTokenErrorCodeUnsupportedGrantType, | ||||
|  | @ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { | ||||
| 	token, err := oauth2.ParseToken(form.RefreshToken) | ||||
| func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { | ||||
| 	token, err := oauth2.ParseToken(form.RefreshToken, serverKey) | ||||
| 	if err != nil { | ||||
| 		handleAccessTokenError(ctx, AccessTokenError{ | ||||
| 			ErrorCode:        AccessTokenErrorCodeUnauthorizedClient, | ||||
|  | @ -598,7 +599,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin | |||
| 		log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID) | ||||
| 		return | ||||
| 	} | ||||
| 	accessToken, tokenErr := newAccessTokenResponse(grant, signingKey) | ||||
| 	accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey) | ||||
| 	if tokenErr != nil { | ||||
| 		handleAccessTokenError(ctx, *tokenErr) | ||||
| 		return | ||||
|  | @ -606,7 +607,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin | |||
| 	ctx.JSON(http.StatusOK, accessToken) | ||||
| } | ||||
| 
 | ||||
| func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { | ||||
| func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { | ||||
| 	app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) | ||||
| 	if err != nil { | ||||
| 		handleAccessTokenError(ctx, AccessTokenError{ | ||||
|  | @ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s | |||
| 			ErrorDescription: "cannot proceed your request", | ||||
| 		}) | ||||
| 	} | ||||
| 	resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey) | ||||
| 	resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey) | ||||
| 	if tokenErr != nil { | ||||
| 		handleAccessTokenError(ctx, *tokenErr) | ||||
| 		return | ||||
|  |  | |||
|  | @ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo | |||
| 	signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32)) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.NotNil(t, signingKey) | ||||
| 	oauth2.DefaultSigningKey = signingKey | ||||
| 
 | ||||
| 	response, terr := newAccessTokenResponse(grant, signingKey) | ||||
| 	response, terr := newAccessTokenResponse(grant, signingKey, signingKey) | ||||
| 	assert.Nil(t, terr) | ||||
| 	assert.NotNil(t, response) | ||||
| 
 | ||||
|  |  | |||
|  | @ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 { | |||
| 	if !strings.Contains(accessToken, ".") { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	token, err := oauth2.ParseToken(accessToken) | ||||
| 	token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey) | ||||
| 	if err != nil { | ||||
| 		log.Trace("ParseOAuth2Token: %v", err) | ||||
| 		log.Trace("oauth2.ParseToken: %v", err) | ||||
| 		return 0 | ||||
| 	} | ||||
| 	var grant *models.OAuth2Grant | ||||
|  |  | |||
|  | @ -40,12 +40,12 @@ type Token struct { | |||
| } | ||||
| 
 | ||||
| // ParseToken parses a signed jwt string
 | ||||
| func ParseToken(jwtToken string) (*Token, error) { | ||||
| func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) { | ||||
| 	parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) { | ||||
| 		if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() { | ||||
| 		if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() { | ||||
| 			return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"]) | ||||
| 		} | ||||
| 		return DefaultSigningKey.VerifyKey(), nil | ||||
| 		return signingKey.VerifyKey(), nil | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  | @ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) { | |||
| } | ||||
| 
 | ||||
| // SignToken signs the token with the JWT secret
 | ||||
| func (token *Token) SignToken() (string, error) { | ||||
| func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) { | ||||
| 	token.IssuedAt = time.Now().Unix() | ||||
| 	jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token) | ||||
| 	DefaultSigningKey.PreProcessToken(jwtToken) | ||||
| 	return jwtToken.SignedString(DefaultSigningKey.SignKey()) | ||||
| 	jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token) | ||||
| 	signingKey.PreProcessToken(jwtToken) | ||||
| 	return jwtToken.SignedString(signingKey.SignKey()) | ||||
| } | ||||
| 
 | ||||
| // OIDCToken represents an OpenID Connect id_token
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue