Commit 36da8fae authored by Gurvinder Singh's avatar Gurvinder Singh
Browse files

Merge branch 'master' into 'master'

Add the ability to override the client ID which is used as the audience of the JWToken

See merge request !1
parents ffefcfd7 a120939e
......@@ -4,7 +4,8 @@
"groups_endpoint": "https://groups-api.example.no/groups/me/groups",
"basic_auth": {
"dataporten_creds": "username:password"
}
},
"client_id": ""
},
"issuer_url": "https://jwt.example.no",
"authorization_endpoint": "https://auth.example.no/oauth/authorization",
......
......@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) {
}, nil
}
func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
func (jwm *JWTMiddleware) JWTTokenHandler(clientIDOverride string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Fetch user and groups from context
user := r.Context().Value(auth.User).(string)
......@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...)
}
// Get Client ID from request headers
clientID, clientIDFound := r.Header["X-Dataporten-Clientid"]
if !clientIDFound {
// If the admin wants to use a specific client ID as the
// audience, always use this instead of what is specified
// through the "X-Dataporten-Clientid" header.
var clientID string
if clientIDOverride != "" {
clientID = clientIDOverride
} else {
// Fallback to the client ID from request headers if the
// admin has not specified a client ID.
clientID = r.Header.Get("X-Dataporten-Clientid")
}
if clientID == "" {
auth.ReturnError(w, r, "No audience", http.StatusBadRequest)
return
}
......@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
email := jwm.getEmail(r.Header.Get("X-Dataporten-Token"))
// Get the JWT token with all the required information
token := jwm.getJWTToken(w, r, user, principals, clientID[0], email)
token := jwm.getJWTToken(w, r, user, principals, clientID, email)
if token == nil {
return
}
log.Info("Issued token for User: ", user, " using App: ", clientID[0])
log.Info("Issued token for User: ", user, " using App: ", clientID)
// We have the token, so send it as json
w.Header().Set("Content-Type", "text/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
......
......@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestVerifyGoodToken(t *testing.T) {
func testFetchMockToken(t *testing.T, url, user, clientID string, expectedClientID string, setClientIDHeader bool, groups []string) error {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil)
r := httptest.NewRequest("GET", url, nil)
bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, []string{"dummygroup"})
ctx = context.WithValue(ctx, auth.User, user)
ctx = context.WithValue(ctx, auth.Groups, groups)
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r)
var clientIDOverride = ""
if setClientIDHeader {
r.Header["X-Dataporten-Clientid"] = []string{clientID}
} else {
clientIDOverride = clientID
}
jwm.JWTTokenHandler(clientIDOverride).ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body)
......@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) {
assert.Nil(t, err)
claims := jwt.Claims{}
claims.SetAudience(r.Header.Get("X-Dataporten-Clientid"))
claims.SetAudience(expectedClientID)
claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{
Expected: claims,
......@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) {
Fn: nil,
}
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
assert.Nil(t, err)
return parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
}
func TestVerifyBadToken(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil)
bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, "dummyPp")
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
func TestVerifyGoodToken(t *testing.T) {
url := "http://example.com/foo"
user := "dummy"
groups := []string{"dummygroup"}
jwm.JWTTokenHandler().ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code)
clientID := "dummyapp"
body, err := ioutil.ReadAll(w.Body)
// Ensure that it is possible to use a specific client ID, instead of the one provided through the header.
err := testFetchMockToken(t, url, user, clientID, clientID, true, groups)
assert.Nil(t, err)
var jwtToken JWTToken
err = json.Unmarshal(body, &jwtToken)
assert.Nil(t, err)
parseJWT, err := jws.ParseJWT([]byte(jwtToken.Token))
// Make sure that it is still possible to fallback to the client ID
// obtained through the X-Dataporten-Clientid header.
err = testFetchMockToken(t, url, user, clientID, clientID, false, groups)
assert.Nil(t, err)
}
claims := jwt.Claims{}
claims.SetAudience("invaliddummyapp")
claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{
Expected: claims,
EXP: 0,
NBF: 10 * time.Second,
Fn: nil,
}
func TestVerifyBadToken(t *testing.T) {
url := "http://example.com/foo"
user := "dummy"
groups := []string{"dummyPp"}
clientID := "dummyapp"
expectedClientID := "invaliddummyapp"
err := testFetchMockToken(t, url, user, clientID, expectedClientID, true, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
err = testFetchMockToken(t, url, user, clientID, expectedClientID, false, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
}
......@@ -117,7 +117,9 @@ func main() {
http.Handle("/healthz", healthzHandler())
http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler())
http.Handle("/jwks", jwks.jwksHandler())
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler()))
clientIDOverride := conf.GetStringValue("engine.dataporten.client_id")
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler(clientIDOverride)))
startTime = time.Now()
port := conf.GetIntValue("server.port")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment