Commit 36da8fae authored by Gurvinder Singh's avatar Gurvinder Singh
Browse files

Merge branch 'master' into 'master'

Add the ability to override the client ID which is used as the audience of the JWToken

See merge request !1
parents ffefcfd7 a120939e
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
"groups_endpoint": "https://groups-api.example.no/groups/me/groups", "groups_endpoint": "https://groups-api.example.no/groups/me/groups",
"basic_auth": { "basic_auth": {
"dataporten_creds": "username:password" "dataporten_creds": "username:password"
} },
"client_id": ""
}, },
"issuer_url": "https://jwt.example.no", "issuer_url": "https://jwt.example.no",
"authorization_endpoint": "https://auth.example.no/oauth/authorization", "authorization_endpoint": "https://auth.example.no/oauth/authorization",
......
...@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) { ...@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) {
}, nil }, nil
} }
func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { func (jwm *JWTMiddleware) JWTTokenHandler(clientIDOverride string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Fetch user and groups from context // Fetch user and groups from context
user := r.Context().Value(auth.User).(string) user := r.Context().Value(auth.User).(string)
...@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { ...@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...) principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...)
} }
// Get Client ID from request headers // If the admin wants to use a specific client ID as the
clientID, clientIDFound := r.Header["X-Dataporten-Clientid"] // audience, always use this instead of what is specified
if !clientIDFound { // through the "X-Dataporten-Clientid" header.
var clientID string
if clientIDOverride != "" {
clientID = clientIDOverride
} else {
// Fallback to the client ID from request headers if the
// admin has not specified a client ID.
clientID = r.Header.Get("X-Dataporten-Clientid")
}
if clientID == "" {
auth.ReturnError(w, r, "No audience", http.StatusBadRequest) auth.ReturnError(w, r, "No audience", http.StatusBadRequest)
return return
} }
...@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { ...@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
email := jwm.getEmail(r.Header.Get("X-Dataporten-Token")) email := jwm.getEmail(r.Header.Get("X-Dataporten-Token"))
// Get the JWT token with all the required information // Get the JWT token with all the required information
token := jwm.getJWTToken(w, r, user, principals, clientID[0], email) token := jwm.getJWTToken(w, r, user, principals, clientID, email)
if token == nil { if token == nil {
return return
} }
log.Info("Issued token for User: ", user, " using App: ", clientID[0]) log.Info("Issued token for User: ", user, " using App: ", clientID)
// We have the token, so send it as json // We have the token, so send it as json
w.Header().Set("Content-Type", "text/json; charset=utf-8") w.Header().Set("Content-Type", "text/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
......
...@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) { ...@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, http.StatusBadRequest, w.Code)
} }
func TestVerifyGoodToken(t *testing.T) { func testFetchMockToken(t *testing.T, url, user, clientID string, expectedClientID string, setClientIDHeader bool, groups []string) error {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil) r := httptest.NewRequest("GET", url, nil)
bytes, _ := ioutil.ReadFile("./test/key.pub") bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes) rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context() ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy") ctx = context.WithValue(ctx, auth.User, user)
ctx = context.WithValue(ctx, auth.Groups, []string{"dummygroup"}) ctx = context.WithValue(ctx, auth.Groups, groups)
r = r.WithContext(ctx) r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r) var clientIDOverride = ""
if setClientIDHeader {
r.Header["X-Dataporten-Clientid"] = []string{clientID}
} else {
clientIDOverride = clientID
}
jwm.JWTTokenHandler(clientIDOverride).ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body) body, err := ioutil.ReadAll(w.Body)
...@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) { ...@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
claims := jwt.Claims{} claims := jwt.Claims{}
claims.SetAudience(r.Header.Get("X-Dataporten-Clientid")) claims.SetAudience(expectedClientID)
claims.SetIssuer(conf.GetStringValue("engine.issuer_url")) claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{ var validator = &jwt.Validator{
Expected: claims, Expected: claims,
...@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) { ...@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) {
Fn: nil, Fn: nil,
} }
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator) return parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
assert.Nil(t, err)
} }
func TestVerifyBadToken(t *testing.T) { func TestVerifyGoodToken(t *testing.T) {
w := httptest.NewRecorder() url := "http://example.com/foo"
r := httptest.NewRequest("GET", "http://example.com/foo", nil) user := "dummy"
bytes, _ := ioutil.ReadFile("./test/key.pub") groups := []string{"dummygroup"}
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, "dummyPp")
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r) clientID := "dummyapp"
assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body) // Ensure that it is possible to use a specific client ID, instead of the one provided through the header.
err := testFetchMockToken(t, url, user, clientID, clientID, true, groups)
assert.Nil(t, err) assert.Nil(t, err)
var jwtToken JWTToken
err = json.Unmarshal(body, &jwtToken) // Make sure that it is still possible to fallback to the client ID
assert.Nil(t, err) // obtained through the X-Dataporten-Clientid header.
parseJWT, err := jws.ParseJWT([]byte(jwtToken.Token)) err = testFetchMockToken(t, url, user, clientID, clientID, false, groups)
assert.Nil(t, err) assert.Nil(t, err)
}
claims := jwt.Claims{} func TestVerifyBadToken(t *testing.T) {
claims.SetAudience("invaliddummyapp") url := "http://example.com/foo"
claims.SetIssuer(conf.GetStringValue("engine.issuer_url")) user := "dummy"
var validator = &jwt.Validator{ groups := []string{"dummyPp"}
Expected: claims,
EXP: 0, clientID := "dummyapp"
NBF: 10 * time.Second, expectedClientID := "invaliddummyapp"
Fn: nil,
} err := testFetchMockToken(t, url, user, clientID, expectedClientID, true, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator) err = testFetchMockToken(t, url, user, clientID, expectedClientID, false, groups)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error()) assert.Equal(t, "claim \"aud\" is invalid", err.Error())
} }
...@@ -117,7 +117,9 @@ func main() { ...@@ -117,7 +117,9 @@ func main() {
http.Handle("/healthz", healthzHandler()) http.Handle("/healthz", healthzHandler())
http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler()) http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler())
http.Handle("/jwks", jwks.jwksHandler()) http.Handle("/jwks", jwks.jwksHandler())
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler()))
clientIDOverride := conf.GetStringValue("engine.dataporten.client_id")
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler(clientIDOverride)))
startTime = time.Now() startTime = time.Now()
port := conf.GetIntValue("server.port") port := conf.GetIntValue("server.port")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment