Decoupled code from DefaultSigningKey (#16743)

Decoupled code from `DefaultSigningKey`. Makes testing a little bit easier and is cleaner.
This commit is contained in:
KN4CK3R 2021-08-27 21:28:00 +02:00 committed by GitHub
parent cd8db3a83d
commit 88abb0dc8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 27 deletions

View File

@ -115,7 +115,7 @@ type AccessTokenResponse struct {
IDToken string `json:"id_token,omitempty"` IDToken string `json:"id_token,omitempty"`
} }
func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { func newAccessTokenResponse(grant *models.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) {
if setting.OAuth2.InvalidateRefreshTokens { if setting.OAuth2.InvalidateRefreshTokens {
if err := grant.IncreaseCounter(); err != nil { if err := grant.IncreaseCounter(); err != nil {
return nil, &AccessTokenError{ return nil, &AccessTokenError{
@ -133,7 +133,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
ExpiresAt: expirationDate.AsTime().Unix(), ExpiresAt: expirationDate.AsTime().Unix(),
}, },
} }
signedAccessToken, err := accessToken.SignToken() signedAccessToken, err := accessToken.SignToken(serverKey)
if err != nil { if err != nil {
return nil, &AccessTokenError{ return nil, &AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorCode: AccessTokenErrorCodeInvalidRequest,
@ -151,7 +151,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
ExpiresAt: refreshExpirationDate, ExpiresAt: refreshExpirationDate,
}, },
} }
signedRefreshToken, err := refreshToken.SignToken() signedRefreshToken, err := refreshToken.SignToken(serverKey)
if err != nil { if err != nil {
return nil, &AccessTokenError{ return nil, &AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorCode: AccessTokenErrorCodeInvalidRequest,
@ -207,7 +207,7 @@ func newAccessTokenResponse(grant *models.OAuth2Grant, signingKey oauth2.JWTSign
idToken.EmailVerified = user.IsActive idToken.EmailVerified = user.IsActive
} }
signedIDToken, err = idToken.SignToken(signingKey) signedIDToken, err = idToken.SignToken(clientKey)
if err != nil { if err != nil {
return nil, &AccessTokenError{ return nil, &AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorCode: AccessTokenErrorCodeInvalidRequest,
@ -265,7 +265,7 @@ func IntrospectOAuth(ctx *context.Context) {
} }
form := web.GetForm(ctx).(*forms.IntrospectTokenForm) form := web.GetForm(ctx).(*forms.IntrospectTokenForm)
token, err := oauth2.ParseToken(form.Token) token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey)
if err == nil { if err == nil {
if token.Valid() == nil { if token.Valid() == nil {
grant, err := models.GetOAuth2GrantByID(token.GrantID) grant, err := models.GetOAuth2GrantByID(token.GrantID)
@ -544,9 +544,11 @@ func AccessTokenOAuth(ctx *context.Context) {
} }
} }
signingKey := oauth2.DefaultSigningKey serverKey := oauth2.DefaultSigningKey
if signingKey.IsSymmetric() { clientKey := serverKey
clientKey, err := oauth2.CreateJWTSigningKey(signingKey.SigningMethod().Alg(), []byte(form.ClientSecret)) if serverKey.IsSymmetric() {
var err error
clientKey, err = oauth2.CreateJWTSigningKey(serverKey.SigningMethod().Alg(), []byte(form.ClientSecret))
if err != nil { if err != nil {
handleAccessTokenError(ctx, AccessTokenError{ handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorCode: AccessTokenErrorCodeInvalidRequest,
@ -554,14 +556,13 @@ func AccessTokenOAuth(ctx *context.Context) {
}) })
return return
} }
signingKey = clientKey
} }
switch form.GrantType { switch form.GrantType {
case "refresh_token": case "refresh_token":
handleRefreshToken(ctx, form, signingKey) handleRefreshToken(ctx, form, serverKey, clientKey)
case "authorization_code": case "authorization_code":
handleAuthorizationCode(ctx, form, signingKey) handleAuthorizationCode(ctx, form, serverKey, clientKey)
default: default:
handleAccessTokenError(ctx, AccessTokenError{ handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeUnsupportedGrantType, ErrorCode: AccessTokenErrorCodeUnsupportedGrantType,
@ -570,8 +571,8 @@ func AccessTokenOAuth(ctx *context.Context) {
} }
} }
func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
token, err := oauth2.ParseToken(form.RefreshToken) token, err := oauth2.ParseToken(form.RefreshToken, serverKey)
if err != nil { if err != nil {
handleAccessTokenError(ctx, AccessTokenError{ handleAccessTokenError(ctx, AccessTokenError{
ErrorCode: AccessTokenErrorCodeUnauthorizedClient, ErrorCode: AccessTokenErrorCodeUnauthorizedClient,
@ -598,7 +599,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID) log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID)
return return
} }
accessToken, tokenErr := newAccessTokenResponse(grant, signingKey) accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey)
if tokenErr != nil { if tokenErr != nil {
handleAccessTokenError(ctx, *tokenErr) handleAccessTokenError(ctx, *tokenErr)
return return
@ -606,7 +607,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, signin
ctx.JSON(http.StatusOK, accessToken) ctx.JSON(http.StatusOK, accessToken)
} }
func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, signingKey oauth2.JWTSigningKey) { func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) {
app, err := models.GetOAuth2ApplicationByClientID(form.ClientID) app, err := models.GetOAuth2ApplicationByClientID(form.ClientID)
if err != nil { if err != nil {
handleAccessTokenError(ctx, AccessTokenError{ handleAccessTokenError(ctx, AccessTokenError{
@ -660,7 +661,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s
ErrorDescription: "cannot proceed your request", ErrorDescription: "cannot proceed your request",
}) })
} }
resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, signingKey) resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey)
if tokenErr != nil { if tokenErr != nil {
handleAccessTokenError(ctx, *tokenErr) handleAccessTokenError(ctx, *tokenErr)
return return

View File

@ -18,9 +18,8 @@ func createAndParseToken(t *testing.T, grant *models.OAuth2Grant) *oauth2.OIDCTo
signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32)) signingKey, err := oauth2.CreateJWTSigningKey("HS256", make([]byte, 32))
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, signingKey) assert.NotNil(t, signingKey)
oauth2.DefaultSigningKey = signingKey
response, terr := newAccessTokenResponse(grant, signingKey) response, terr := newAccessTokenResponse(grant, signingKey, signingKey)
assert.Nil(t, terr) assert.Nil(t, terr)
assert.NotNil(t, response) assert.NotNil(t, response)

View File

@ -29,9 +29,9 @@ func CheckOAuthAccessToken(accessToken string) int64 {
if !strings.Contains(accessToken, ".") { if !strings.Contains(accessToken, ".") {
return 0 return 0
} }
token, err := oauth2.ParseToken(accessToken) token, err := oauth2.ParseToken(accessToken, oauth2.DefaultSigningKey)
if err != nil { if err != nil {
log.Trace("ParseOAuth2Token: %v", err) log.Trace("oauth2.ParseToken: %v", err)
return 0 return 0
} }
var grant *models.OAuth2Grant var grant *models.OAuth2Grant

View File

@ -40,12 +40,12 @@ type Token struct {
} }
// ParseToken parses a signed jwt string // ParseToken parses a signed jwt string
func ParseToken(jwtToken string) (*Token, error) { func ParseToken(jwtToken string, signingKey JWTSigningKey) (*Token, error) {
parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) { parsedToken, err := jwt.ParseWithClaims(jwtToken, &Token{}, func(token *jwt.Token) (interface{}, error) {
if token.Method == nil || token.Method.Alg() != DefaultSigningKey.SigningMethod().Alg() { if token.Method == nil || token.Method.Alg() != signingKey.SigningMethod().Alg() {
return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing algo: %v", token.Header["alg"])
} }
return DefaultSigningKey.VerifyKey(), nil return signingKey.VerifyKey(), nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -59,11 +59,11 @@ func ParseToken(jwtToken string) (*Token, error) {
} }
// SignToken signs the token with the JWT secret // SignToken signs the token with the JWT secret
func (token *Token) SignToken() (string, error) { func (token *Token) SignToken(signingKey JWTSigningKey) (string, error) {
token.IssuedAt = time.Now().Unix() token.IssuedAt = time.Now().Unix()
jwtToken := jwt.NewWithClaims(DefaultSigningKey.SigningMethod(), token) jwtToken := jwt.NewWithClaims(signingKey.SigningMethod(), token)
DefaultSigningKey.PreProcessToken(jwtToken) signingKey.PreProcessToken(jwtToken)
return jwtToken.SignedString(DefaultSigningKey.SignKey()) return jwtToken.SignedString(signingKey.SignKey())
} }
// OIDCToken represents an OpenID Connect id_token // OIDCToken represents an OpenID Connect id_token