Commit a120939e authored by Pål Karlsrud's avatar Pål Karlsrud
Browse files

Add the ability to override the client ID which is used as the

audience of the JWToken.

Currently, the client ID obtained through the X-Dataporten-Clientid
header is always used.
In order to make it easier to associated JWTokens with specific
Kubernetes clusters, we want to make it possible to always use
a specific client ID as the audience.
parent ffefcfd7
......@@ -4,7 +4,8 @@
"groups_endpoint": "https://groups-api.example.no/groups/me/groups",
"basic_auth": {
"dataporten_creds": "username:password"
}
},
"client_id": ""
},
"issuer_url": "https://jwt.example.no",
"authorization_endpoint": "https://auth.example.no/oauth/authorization",
......
......@@ -39,7 +39,7 @@ func newJWTMiddleWare(keyPath string) (*JWTMiddleware, error) {
}, nil
}
func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
func (jwm *JWTMiddleware) JWTTokenHandler(clientIDOverride string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Fetch user and groups from context
user := r.Context().Value(auth.User).(string)
......@@ -56,9 +56,19 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
principals = append(principals, jwm.getMASgroups(r.Header.Get("X-Dataporten-Userid-Sec"))...)
}
// Get Client ID from request headers
clientID, clientIDFound := r.Header["X-Dataporten-Clientid"]
if !clientIDFound {
// If the admin wants to use a specific client ID as the
// audience, always use this instead of what is specified
// through the "X-Dataporten-Clientid" header.
var clientID string
if clientIDOverride != "" {
clientID = clientIDOverride
} else {
// Fallback to the client ID from request headers if the
// admin has not specified a client ID.
clientID = r.Header.Get("X-Dataporten-Clientid")
}
if clientID == "" {
auth.ReturnError(w, r, "No audience", http.StatusBadRequest)
return
}
......@@ -67,12 +77,12 @@ func (jwm *JWTMiddleware) JWTTokenHandler() http.Handler {
email := jwm.getEmail(r.Header.Get("X-Dataporten-Token"))
// Get the JWT token with all the required information
token := jwm.getJWTToken(w, r, user, principals, clientID[0], email)
token := jwm.getJWTToken(w, r, user, principals, clientID, email)
if token == nil {
return
}
log.Info("Issued token for User: ", user, " using App: ", clientID[0])
log.Info("Issued token for User: ", user, " using App: ", clientID)
// We have the token, so send it as json
w.Header().Set("Content-Type", "text/json; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
......
......@@ -41,19 +41,25 @@ func TestGetJWTToken(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestVerifyGoodToken(t *testing.T) {
func testFetchMockToken(t *testing.T, url, user, clientID string, expectedClientID string, setClientIDHeader bool, groups []string) error {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil)
r := httptest.NewRequest("GET", url, nil)
bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, []string{"dummygroup"})
ctx = context.WithValue(ctx, auth.User, user)
ctx = context.WithValue(ctx, auth.Groups, groups)
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
jwm.JWTTokenHandler().ServeHTTP(w, r)
var clientIDOverride = ""
if setClientIDHeader {
r.Header["X-Dataporten-Clientid"] = []string{clientID}
} else {
clientIDOverride = clientID
}
jwm.JWTTokenHandler(clientIDOverride).ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code)
body, err := ioutil.ReadAll(w.Body)
......@@ -65,7 +71,7 @@ func TestVerifyGoodToken(t *testing.T) {
assert.Nil(t, err)
claims := jwt.Claims{}
claims.SetAudience(r.Header.Get("X-Dataporten-Clientid"))
claims.SetAudience(expectedClientID)
claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{
Expected: claims,
......@@ -74,44 +80,39 @@ func TestVerifyGoodToken(t *testing.T) {
Fn: nil,
}
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
assert.Nil(t, err)
return parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
}
func TestVerifyBadToken(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "http://example.com/foo", nil)
bytes, _ := ioutil.ReadFile("./test/key.pub")
rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)
ctx := r.Context()
ctx = context.WithValue(ctx, auth.User, "dummy")
ctx = context.WithValue(ctx, auth.Groups, "dummyPp")
r = r.WithContext(ctx)
r.Header["X-Dataporten-Clientid"] = []string{"dummyapp"}
func TestVerifyGoodToken(t *testing.T) {
url := "http://example.com/foo"
user := "dummy"
groups := []string{"dummygroup"}
jwm.JWTTokenHandler().ServeHTTP(w, r)
assert.Equal(t, http.StatusCreated, w.Code)
clientID := "dummyapp"
body, err := ioutil.ReadAll(w.Body)
// Ensure that it is possible to use a specific client ID, instead of the one provided through the header.
err := testFetchMockToken(t, url, user, clientID, clientID, true, groups)
assert.Nil(t, err)
var jwtToken JWTToken
err = json.Unmarshal(body, &jwtToken)
assert.Nil(t, err)
parseJWT, err := jws.ParseJWT([]byte(jwtToken.Token))
// Make sure that it is still possible to fallback to the client ID
// obtained through the X-Dataporten-Clientid header.
err = testFetchMockToken(t, url, user, clientID, clientID, false, groups)
assert.Nil(t, err)
}
claims := jwt.Claims{}
claims.SetAudience("invaliddummyapp")
claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
var validator = &jwt.Validator{
Expected: claims,
EXP: 0,
NBF: 10 * time.Second,
Fn: nil,
}
func TestVerifyBadToken(t *testing.T) {
url := "http://example.com/foo"
user := "dummy"
groups := []string{"dummyPp"}
clientID := "dummyapp"
expectedClientID := "invaliddummyapp"
err := testFetchMockToken(t, url, user, clientID, expectedClientID, true, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
err = parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
err = testFetchMockToken(t, url, user, clientID, expectedClientID, false, groups)
assert.NotNil(t, err)
assert.Equal(t, "claim \"aud\" is invalid", err.Error())
}
......@@ -117,7 +117,9 @@ func main() {
http.Handle("/healthz", healthzHandler())
http.Handle("/.well-known/openid-configuration", oidcConf.openidConfigHandler())
http.Handle("/jwks", jwks.jwksHandler())
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler()))
clientIDOverride := conf.GetStringValue("engine.dataporten.client_id")
http.Handle("/", auth.MiddlewareHandler(jwm.JWTTokenHandler(clientIDOverride)))
startTime = time.Now()
port := conf.GetIntValue("server.port")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment