Commit a120939e authored by Pål Karlsrud's avatar Pål Karlsrud
Browse files

Add the ability to override the client ID which is used as the

audience of the JWToken.

Currently, the client ID obtained through the X-Dataporten-Clientid
header is always used.
In order to make it easier to associated JWTokens with specific
Kubernetes clusters, we want to make it possible to always use
a specific client ID as the audience.
parent ffefcfd7
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
"groups_endpoint": "https://groups-api.example.no/groups/me/groups", "groups_endpoint": "https://groups-api.example.no/groups/me/groups",
"basic_auth": { "basic_auth": {
"dataporten_creds": "username:password" "dataporten_creds": "username:password"
} },
"client_id": ""
}, },
"issuer_url": "https://jwt.example.no", "issuer_url": "https://jwt.example.no",
"authorization_endpoint": "https://auth.example.no/oauth/authorization", "authorization_endpoint": "https://auth.example.no/oauth/authorization",
......
...@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) { ...@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) {
}, nil }, nil
} }
func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { func (jwm *JWTMiddleware) JWTTokenHandler(clientIDOverride string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Fetch user and groups from context // Fetch user and groups from context
user := r.Context().Value(auth.User).(string) user := r.Context().Value(auth.User).(string)
...@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { ...@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...) principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...)
} }
// Get Client ID from request headers // If the admin wants to use a specific client ID as the
clientID, clientIDFound := r.Header["X-Dataporten-Clientid"] // audience, always use this instead of what is specified
if !clientIDFound { // through the "X-Dataporten-Clientid" header.
var clientID string
if clientIDOverride != "" {
clientID = clientIDOverride
} else {
// Fallback to the client ID from request headers if the
// admin has not specified a client ID.
clientID = r.Header.Get("X-Dataporten-Clientid")
}
if clientID == "" {
auth.ReturnError(w, r, "No audience", http.StatusBadRequest) auth.ReturnError(w, r, "No audience", http.StatusBadRequest)
return return
} }
...@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler { ...@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
email := jwm.getEmail(r.Header.Get("X-Dataporten-Token")) email := jwm.getEmail(r.Header.Get("X-Dataporten-Token"))
// Get the JWT token with all the required information // Get the JWT token with all the required information
token := jwm.getJWTToken(w, r, user, principals, clientID[0], email) token := jwm.getJWTToken(w, r, user, principals, clientID, email)
if token == nil { if token == nil {
return return
} }
log.Info("Issued token for User: ", user, " using App: ", clientID[0]) log.Info("Issued token for User: ", user, " using App: ", clientID)
// We have the token, so send it as json // We have the token, so send it as json
w.Header().Set("Content-Type", "text/json; charset=utf-8") w.Header().Set("Content-Type", "text/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
......
...@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) { ...@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, w.Code) assert.Equal(t, http.StatusBadRequest, w.Code)
} }
func TestVerifyGoodToken(t *testing.T) { func testFetchMockToken(t *testing.T, url, user, clientID string, expectedClientID string, setClientIDHeader bool, groups []string) error {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil) r := httptest.NewRequest("GET", url, nil)
bytes, _ := ioutil.ReadFile("./test/key.pub") bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes) rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context() ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy") ctx = context.WithValue(ctx, auth.User, user)
ctx = context.WithValue(ctx, auth.Groups, []string{"dummygroup"}) ctx = context.WithValue(ctx, auth.Groups, groups)
r = r.WithContext(ctx) r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r) var clientIDOverride = ""
if setClientIDHeader {
r.Header["X-Dataporten-Clientid"] = []string{clientID}
} else {
clientIDOverride = clientID
}
jwm.JWTTokenHandler(clientIDOverride).ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code) assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body) body, err := ioutil.ReadAll(w.Body)
...@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) { ...@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
claims := jwt.Claims{} claims := jwt.Claims{}
claims.SetAudience(r.Header.Get("X-Dataporten-Clientid")) claims.SetAudience(expectedClientID)
claims.SetIssuer(conf.GetStringValue("engine.issuer_url")) claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{ var validator = &jwt.Validator{
Expected: claims, Expected: claims,
...@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) { ...@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) {
Fn: nil, Fn: nil,
} }
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator) return parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
assert.Nil(t, err)
} }
func TestVerifyBadToken(t *testing.T) { func TestVerifyGoodToken(t *testing.T) {
w := httptest.NewRecorder() url := "http://example.com/foo"
r := httptest.NewRequest("GET", "http://example.com/foo", nil) user := "dummy"
bytes, _ := ioutil.ReadFile("./test/key.pub") groups := []string{"dummygroup"}
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, "dummyPp")
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r) clientID := "dummyapp"
assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body) // Ensure that it is possible to use a specific client ID, instead of the one provided through the header.
err := testFetchMockToken(t, url, user, clientID, clientID, true, groups)
assert.Nil(t, err) assert.Nil(t, err)
var jwtToken JWTToken
err = json.Unmarshal(body, &jwtToken) // Make sure that it is still possible to fallback to the client ID
assert.Nil(t, err) // obtained through the X-Dataporten-Clientid header.
parseJWT, err := jws.ParseJWT([]byte(jwtToken.Token)) err = testFetchMockToken(t, url, user, clientID, clientID, false, groups)
assert.Nil(t, err) assert.Nil(t, err)
}
claims := jwt.Claims{} func TestVerifyBadToken(t *testing.T) {
claims.SetAudience("invaliddummyapp") url := "http://example.com/foo"
claims.SetIssuer(conf.GetStringValue("engine.issuer_url")) user := "dummy"
var validator = &jwt.Validator{ groups := []string{"dummyPp"}
Expected: claims,
EXP: 0, clientID := "dummyapp"
NBF: 10 * time.Second, expectedClientID := "invaliddummyapp"
Fn: nil,
} err := testFetchMockToken(t, url, user, clientID, expectedClientID, true, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator) err = testFetchMockToken(t, url, user, clientID, expectedClientID, false, groups)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error()) assert.Equal(t, "claim \"aud\" is invalid", err.Error())
} }
...@@ -117,7 +117,9 @@ func main() { ...@@ -117,7 +117,9 @@ func main() {
http.Handle("/healthz", healthzHandler()) http.Handle("/healthz", healthzHandler())
http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler()) http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler())
http.Handle("/jwks", jwks.jwksHandler()) http.Handle("/jwks", jwks.jwksHandler())
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler()))
clientIDOverride := conf.GetStringValue("engine.dataporten.client_id")
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler(clientIDOverride)))
startTime = time.Now() startTime = time.Now()
port := conf.GetIntValue("server.port") port := conf.GetIntValue("server.port")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment