// jwt_test.go: tests for issuing and validating JWTs via the jwm middleware.

package main

import (
	"context"
	"encoding/json"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/SermoDigital/jose/crypto"
	"github.com/SermoDigital/jose/jws"
	"github.com/SermoDigital/jose/jwt"
	"github.com/stretchr/testify/assert"
	"scm.uninett.no/daas/jwt-tokenissuer/conf"
	auth "scm.uninett.no/laas/laasctl-auth"
)

// JWTToken mirrors the JSON body returned by the token handler: {"token": "<jwt>"}.
type JWTToken struct{ Token string `json:"token"` }

// jwm is the middleware under test, built from the RSA private key in
// ./test/key.priv. NOTE(review): presumably the counterpart of ./test/key.pub,
// which testFetchMockToken uses to verify issued tokens — confirm.
var jwm, jwmInitErr = newJWTMiddleWare("./test/key.priv")

// init aborts the test binary when the test key cannot be loaded. Silently
// discarding that error would let every test run against an unusable jwm and
// fail with a misleading symptom instead of the real cause.
func init() {
	if jwmInitErr != nil {
		panic(jwmInitErr)
	}
}

// Tests
func TestGetJWTToken(t *testing.T) {

	w := httptest.NewRecorder()
	r := httptest.NewRequest("GET", "http://example.com/foo", nil)

32
	token := jwm.getJWTToken(w, r, "dummy", []string{"dummyPp"}, "dummyapp", "test@test.com")
33
	assert.NotNil(t, token)
34
	token = jwm.getJWTToken(w, r, "dummy", nil, "dummyapp", "test@test.com")
35
	assert.NotNil(t, token)
36
	token = jwm.getJWTToken(w, r, "dummy", nil, "", "test@test.com")
37
38
	assert.Nil(t, token)
	assert.Equal(t, http.StatusBadRequest, w.Code)
39
	token = jwm.getJWTToken(w, r, "", nil, "", "test@test.com")
40
41
42
43
	assert.Nil(t, token)
	assert.Equal(t, http.StatusBadRequest, w.Code)
}

func testFetchMockToken(t *testing.T, url, user, clientID string, expectedClientID string, setClientIDHeader bool, groups []string) error {
45
	w := httptest.NewRecorder()
46
	r := httptest.NewRequest("GET", url, nil)
47
	bytes, _ := ioutil.ReadFile("./test/key.pub")
48
49
50
	rsaPublic, _ := crypto.ParseRSAPublicKeyFromPEM(bytes)

	ctx := r.Context()
51
52
	ctx = context.WithValue(ctx, auth.User, user)
	ctx = context.WithValue(ctx, auth.Groups, groups)
53
54
	r = r.WithContext(ctx)

55
56
57
58
59
60
61
62
	var clientIDOverride = ""
	if setClientIDHeader {
		r.Header["X-Dataporten-Clientid"] = []string{clientID}
	} else {
		clientIDOverride = clientID
	}

	jwm.JWTTokenHandler(clientIDOverride).ServeHTTP(w, r)
63
64
65
66
67
68
69
70
71
72
73
	assert.Equal(t, http.StatusCreated, w.Code)

	body, err := ioutil.ReadAll(w.Body)
	assert.Nil(t, err)
	var jwtToken JWTToken
	err = json.Unmarshal(body, &jwtToken)
	assert.Nil(t, err)
	parseJWT, err := jws.ParseJWT([]byte(jwtToken.Token))
	assert.Nil(t, err)

	claims := jwt.Claims{}
74
	claims.SetAudience(expectedClientID)
75
76
77
78
79
80
81
82
	claims.SetIssuer(conf.GetStringValue("engine.issuer_url"))
	var validator = &jwt.Validator{
		Expected: claims,
		EXP:      0,
		NBF:      10 * time.Second,
		Fn:       nil,
	}

83
	return parseJWT.Validate(rsaPublic, crypto.SigningMethodRS256, validator)
84
85
}

func TestVerifyGoodToken(t *testing.T) {
	url := "http://example.com/foo"
	user := "dummy"
	groups := []string{"dummygroup"}
90

91
	clientID := "dummyapp"
92

93
94
	// Ensure that it is possible to use a specific client ID, instead of the one provided through the header.
	err := testFetchMockToken(t, url, user, clientID, clientID, true, groups)
95
	assert.Nil(t, err)
96
97
98
99

	// Make sure that it is still possible to fallback to the client ID
	// obtained through the X-Dataporten-Clientid header.
	err = testFetchMockToken(t, url, user, clientID, clientID, false, groups)
100
	assert.Nil(t, err)
101
}
// TestVerifyBadToken checks that validation rejects an issued token with an
// "aud" error when the expected audience differs from the client ID the token
// was issued for, for both ways of supplying the client ID (header and
// override).
func TestVerifyBadToken(t *testing.T) {
	url := "http://example.com/foo"
	user := "dummy"
	groups := []string{"dummyPp"}

	clientID := "dummyapp"
	expectedClientID := "invaliddummyapp"

	for _, viaHeader := range []bool{true, false} {
		err := testFetchMockToken(t, url, user, clientID, expectedClientID, viaHeader, groups)
		// EqualError reports a failure (instead of panicking on err.Error())
		// if validation unexpectedly succeeds and err is nil.
		assert.EqualError(t, err, "claim \"aud\" is invalid", "setClientIDHeader=%v", viaHeader)
	}
}