Commit 5ffb1376 authored by venaas's avatar venaas Committed by venaas
Browse files

fixed bug with multiple status server sent, some dtls fixes

git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@358 e88ac4ed-0b26-0410-9574-a7f39faa03bf
parent a2a0f702
......@@ -36,6 +36,18 @@
static int client4_sock = -1;
static int client6_sock = -1;
struct sessioncacheentry {
pthread_mutex_t mutex;
struct queue *rbios;
struct timeval expiry;
};
struct dtlsservernewparams {
struct sessioncacheentry *sesscache;
int sock;
struct sockaddr_storage addr;
};
int udp2bio(int s, struct queue *q, int cnt) {
unsigned char *buf;
BIO *rbio;
......@@ -92,115 +104,20 @@ BIO *getrbio(SSL *ssl, struct queue *q, int timeout) {
return rbio;
}
void *udpdtlsserverrd(void *arg) {
int cnt, s = *(int *)arg;
unsigned char buf[4];
struct sockaddr_storage from;
socklen_t fromlen = sizeof(from);
struct client *client;
fd_set readfds;
pthread_t dtlsserverth;
struct hash *rbiosh;
struct queue *rbiosq;
rbiosh = hash_create();
if (!rbiosh)
debugx(1, DBG_ERR, "udpdtlsserverrd: malloc failed");
for (;;) {
FD_ZERO(&readfds);
FD_SET(s, &readfds);
if (select(s + 1, &readfds, NULL, NULL, NULL) < 1)
continue;
cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
if (cnt == -1) {
debug(DBG_WARN, "udpdtlsserverrd: recv failed");
continue;
}
rbiosq = hash_read(rbiosh, &from, fromlen);
if (rbiosq) {
if (udp2bio(s, rbiosq, cnt))
debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from, fromlen));
continue;
}
/* from new source, new client? */
client = malloc(sizeof(struct client));
if (!client)
continue;
client->rbios = rbiosq = newqueue();
if (!hash_insert(rbiosh, &from, fromlen, rbiosq)) {
free(client);
removequeue(rbiosq);
continue;
}
client->sock = s;
memcpy(&client->addr, &from, fromlen);
if (udp2bio(s, rbiosq, cnt)) {
debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from, fromlen));
if (!pthread_create(&dtlsserverth, NULL, dtlsservernew, (void *)client)) {
pthread_detach(dtlsserverth);
continue;
}
debug(DBG_ERR, "udpdtlsserverrd: pthread_create failed");
}
free(client);
freebios(hash_extract(rbiosh, &from, fromlen));
}
}
void *dtlsserverwr(void *arg) {
int cnt;
unsigned long error;
struct client *client = (struct client *)arg;
struct queue *replyq;
struct reply *reply;
debug(DBG_DBG, "dtlsserverwr: starting for %s", client->conf->host);
replyq = client->replyq;
for (;;) {
pthread_mutex_lock(&replyq->mutex);
while (!list_first(replyq->entries)) {
if (client->ssl) {
debug(DBG_DBG, "dtlsserverwr: waiting for signal");
pthread_cond_wait(&replyq->cond, &replyq->mutex);
debug(DBG_DBG, "dtlsserverwr: got signal");
}
if (!client->ssl) {
/* ssl might have changed while waiting */
pthread_mutex_unlock(&replyq->mutex);
debug(DBG_DBG, "dtlsserverwr: exiting as requested");
ERR_remove_state(0);
pthread_exit(NULL);
}
}
reply = (struct reply *)list_shift(replyq->entries);
pthread_mutex_unlock(&replyq->mutex);
cnt = SSL_write(client->ssl, reply->buf, RADLEN(reply->buf));
if (cnt > 0)
debug(DBG_DBG, "dtlsserverwr: sent %d bytes, Radius packet of length %d",
cnt, RADLEN(reply->buf));
else
while ((error = ERR_get_error()))
debug(DBG_ERR, "dtlsserverwr: SSL: %s", ERR_error_string(error, NULL));
free(reply->buf);
free(reply);
}
}
int dtlsread(SSL *ssl, struct queue *q, unsigned char *buf, int num) {
int dtlsread(SSL *ssl, struct queue *q, unsigned char *buf, int num, int timeout) {
int len, cnt;
BIO *rbio;
for (len = 0; len < num; len += cnt) {
cnt = SSL_read(ssl, buf + len, num - len);
if (cnt <= 0)
switch (cnt = SSL_get_error(ssl, cnt)) {
case SSL_ERROR_WANT_READ:
rbio = getrbio(ssl, q, timeout);
if (!rbio)
return 0;
BIO_free(ssl->rbio);
ssl->rbio = getrbio(ssl, q, 0);
if (!ssl->rbio)
return -1;
ssl->rbio = rbio;
cnt = 0;
continue;
case SSL_ERROR_WANT_WRITE:
......@@ -217,14 +134,51 @@ int dtlsread(SSL *ssl, struct queue *q, unsigned char *buf, int num) {
return num;
}
unsigned char *raddtlsget(SSL *ssl, struct queue *rbios) {
/* accept if acc == 1, else connect */
SSL *dtlsacccon(uint8_t acc, SSL_CTX *ctx, int s, struct sockaddr *addr, struct queue *rbios) {
SSL *ssl;
int i, res;
unsigned long error;
BIO *mem0bio, *wbio;
ssl = SSL_new(ctx);
if (!ssl)
return NULL;
mem0bio = BIO_new(BIO_s_mem());
BIO_set_mem_eof_return(mem0bio, -1);
wbio = BIO_new_dgram(s, BIO_NOCLOSE);
BIO_dgram_set_peer(wbio, addr);
SSL_set_bio(ssl, mem0bio, wbio);
for (i = 0; i < 5; i++) {
res = acc ? SSL_accept(ssl) : SSL_connect(ssl);
if (res > 0)
return ssl;
if (res == 0)
break;
if (SSL_get_error(ssl, res) == SSL_ERROR_WANT_READ) {
BIO_free(ssl->rbio);
ssl->rbio = getrbio(ssl, rbios, 5);
if (!ssl->rbio)
break;
}
while ((error = ERR_get_error()))
debug(DBG_ERR, "dtls%st: DTLS: %s", acc ? "accep" : "connec", ERR_error_string(error, NULL));
}
SSL_free(ssl);
return NULL;
}
unsigned char *raddtlsget(SSL *ssl, struct queue *rbios, int timeout) {
int cnt, len;
unsigned char buf[4], *rad;
for (;;) {
cnt = dtlsread(ssl, rbios, buf, 4);
cnt = dtlsread(ssl, rbios, buf, 4, timeout);
if (cnt < 1) {
debug(DBG_DBG, "raddtlsget: connection lost");
debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout");
return NULL;
}
......@@ -236,9 +190,9 @@ unsigned char *raddtlsget(SSL *ssl, struct queue *rbios) {
}
memcpy(rad, buf, 4);
cnt = dtlsread(ssl, rbios, rad + 4, len - 4);
cnt = dtlsread(ssl, rbios, rad + 4, len - 4, timeout);
if (cnt < 1) {
debug(DBG_DBG, "raddtlsget: connection lost");
debug(DBG_DBG, cnt ? "raddtlsget: connection lost" : "raddtlsget: timeout");
free(rad);
return NULL;
}
......@@ -254,6 +208,45 @@ unsigned char *raddtlsget(SSL *ssl, struct queue *rbios) {
return rad;
}
void *dtlsserverwr(void *arg) {
int cnt;
unsigned long error;
struct client *client = (struct client *)arg;
struct queue *replyq;
struct reply *reply;
debug(DBG_DBG, "dtlsserverwr: starting for %s", client->conf->host);
replyq = client->replyq;
for (;;) {
pthread_mutex_lock(&replyq->mutex);
while (!list_first(replyq->entries)) {
if (client->ssl) {
debug(DBG_DBG, "dtlsserverwr: waiting for signal");
pthread_cond_wait(&replyq->cond, &replyq->mutex);
debug(DBG_DBG, "dtlsserverwr: got signal");
}
if (!client->ssl) {
/* ssl might have changed while waiting */
pthread_mutex_unlock(&replyq->mutex);
debug(DBG_DBG, "dtlsserverwr: exiting as requested");
ERR_remove_state(0);
pthread_exit(NULL);
}
}
reply = (struct reply *)list_shift(replyq->entries);
pthread_mutex_unlock(&replyq->mutex);
cnt = SSL_write(client->ssl, reply->buf, RADLEN(reply->buf));
if (cnt > 0)
debug(DBG_DBG, "dtlsserverwr: sent %d bytes, Radius packet of length %d",
cnt, RADLEN(reply->buf));
else
while ((error = ERR_get_error()))
debug(DBG_ERR, "dtlsserverwr: SSL: %s", ERR_error_string(error, NULL));
free(reply->buf);
free(reply);
}
}
void dtlsserverrd(struct client *client) {
struct request rq;
pthread_t dtlsserverwrth;
......@@ -267,7 +260,7 @@ void dtlsserverrd(struct client *client) {
for (;;) {
memset(&rq, 0, sizeof(struct request));
rq.buf = raddtlsget(client->ssl, client->rbios);
rq.buf = raddtlsget(client->ssl, client->rbios, IDLE_TIMEOUT);
if (!rq.buf) {
debug(DBG_ERR, "dtlsserverrd: connection from %s lost", client->conf->host);
break;
......@@ -292,61 +285,19 @@ void dtlsserverrd(struct client *client) {
debug(DBG_DBG, "dtlsserverrd: reader for %s exiting", client->conf->host);
}
/* accept if acc == 1, else connect */
SSL *dtlsacccon(uint8_t acc, SSL_CTX *ctx, int s, struct sockaddr *addr, struct queue *rbios) {
SSL *ssl;
int i, res;
unsigned long error;
BIO *mem0bio, *wbio;
ssl = SSL_new(ctx);
if (!ssl)
return NULL;
mem0bio = BIO_new(BIO_s_mem());
BIO_set_mem_eof_return(mem0bio, -1);
wbio = BIO_new_dgram(s, BIO_NOCLOSE);
BIO_dgram_set_peer(wbio, addr);
SSL_set_bio(ssl, mem0bio, wbio);
for (i = 0; i < 5; i++) {
res = acc ? SSL_accept(ssl) : SSL_connect(ssl);
if (res > 0)
return ssl;
if (res == 0)
break;
if (SSL_get_error(ssl, res) == SSL_ERROR_WANT_READ) {
BIO_free(ssl->rbio);
ssl->rbio = getrbio(ssl, rbios, 5);
if (!ssl->rbio)
break;
}
while ((error = ERR_get_error()))
debug(DBG_ERR, "dtls%st: DTLS: %s", acc ? "accep" : "connec", ERR_error_string(error, NULL));
}
SSL_free(ssl);
return NULL;
}
void *dtlsservernew(void *arg) {
struct client *client, *clpar = (struct client *)arg;
struct dtlsservernewparams *params = (struct dtlsservernewparams *)arg;
struct client *client;
struct clsrvconf *conf;
struct list_node *cur = NULL;
int s;
SSL *ssl = NULL;
X509 *cert = NULL;
struct queue *rbios;
struct sockaddr_storage addr;
s = clpar->sock;
rbios = clpar->rbios;
addr = clpar->addr;
free(clpar);
conf = find_clconf(RAD_DTLS, (struct sockaddr *)&addr, NULL);
uint8_t delay = 60;
debug(DBG_DBG, "dtlsservernew: starting");
conf = find_clconf(RAD_DTLS, (struct sockaddr *)&params->addr, NULL);
if (conf) {
ssl = dtlsacccon(1, conf->ssl_ctx, s, (struct sockaddr *)&addr, rbios);
ssl = dtlsacccon(1, conf->ssl_ctx, params->sock, (struct sockaddr *)&params->addr, params->sesscache->rbios);
if (!ssl)
goto exit;
cert = verifytlscert(ssl);
......@@ -359,18 +310,19 @@ void *dtlsservernew(void *arg) {
X509_free(cert);
client = addclient(conf);
if (client) {
client->sock = s;
client->rbios = rbios;
client->addr = addr;
client->sock = params->sock;
client->rbios = params->sesscache->rbios;
client->addr = params->addr;
client->ssl = ssl;
dtlsserverrd(client);
removeclient(client);
delay = 0;
} else {
debug(DBG_WARN, "dtlsservernew: failed to create new client instance");
}
goto exit;
}
conf = find_clconf(RAD_DTLS, (struct sockaddr *)&client->addr, &cur);
conf = find_clconf(RAD_DTLS, (struct sockaddr *)&params->addr, &cur);
}
debug(DBG_WARN, "dtlsservernew: ignoring request, no matching TLS client");
......@@ -378,10 +330,135 @@ void *dtlsservernew(void *arg) {
X509_free(cert);
exit:
/* mark rbios for removal, to be removed by udpdtlsserverrd()*/
SSL_shutdown(ssl);
SSL_free(ssl);
pthread_mutex_lock(&params->sesscache->mutex);
freebios(params->sesscache->rbios);
params->sesscache->rbios = NULL;
gettimeofday(&params->sesscache->expiry, NULL);
params->sesscache->expiry.tv_sec += delay;
pthread_mutex_unlock(&params->sesscache->mutex);
free(params);
ERR_remove_state(0);
pthread_exit(NULL);
debug(DBG_DBG, "dtlsservernew: exiting");
}
void cacheexpire(struct hash *cache, struct timeval *last) {
struct timeval now;
struct hash_entry *he;
struct sessioncacheentry *e;
gettimeofday(&now, NULL);
if (now.tv_sec - last->tv_sec < 19)
return;
for (he = hash_first(cache); he; he = hash_next(he)) {
e = (struct sessioncacheentry *)he->data;
pthread_mutex_lock(&e->mutex);
if (!e->expiry.tv_sec || e->expiry.tv_sec > now.tv_sec) {
pthread_mutex_unlock(&e->mutex);
continue;
}
debug(DBG_DBG, "cacheexpire: freeing entry");
hash_extract(cache, he->key, he->keylen);
if (e->rbios) {
freebios(e->rbios);
e->rbios = NULL;
}
pthread_mutex_unlock(&e->mutex);
pthread_mutex_destroy(&e->mutex);
}
last->tv_sec = now.tv_sec;
}
void *udpdtlsserverrd(void *arg) {
int ndesc, cnt, s = *(int *)arg;
unsigned char buf[4];
struct sockaddr_storage from;
socklen_t fromlen = sizeof(from);
struct dtlsservernewparams *params;
fd_set readfds;
struct timeval timeout, lastexpiry;
pthread_t dtlsserverth;
struct hash *sessioncache;
struct sessioncacheentry *cacheentry;
sessioncache = hash_create();
if (!sessioncache)
debugx(1, DBG_ERR, "udpdtlsserverrd: malloc failed");
gettimeofday(&lastexpiry, NULL);
for (;;) {
FD_ZERO(&readfds);
FD_SET(s, &readfds);
memset(&timeout, 0, sizeof(struct timeval));
timeout.tv_sec = 60;
ndesc = select(s + 1, &readfds, NULL, NULL, &timeout);
if (ndesc < 1) {
cacheexpire(sessioncache, &lastexpiry);
continue;
}
cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
if (cnt == -1) {
debug(DBG_WARN, "udpdtlsserverrd: recv failed");
cacheexpire(sessioncache, &lastexpiry);
continue;
}
cacheentry = hash_read(sessioncache, &from, fromlen);
if (cacheentry) {
debug(DBG_DBG, "udpdtlsserverrd: cache hit");
pthread_mutex_lock(&cacheentry->mutex);
if (cacheentry->rbios) {
if (udp2bio(s, cacheentry->rbios, cnt))
debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from, fromlen));
} else
recv(s, buf, 1, 0);
pthread_mutex_unlock(&cacheentry->mutex);
cacheexpire(sessioncache, &lastexpiry);
continue;
}
/* from new source */
debug(DBG_DBG, "udpdtlsserverrd: cache miss");
params = malloc(sizeof(struct dtlsservernewparams));
if (!params) {
cacheexpire(sessioncache, &lastexpiry);
recv(s, buf, 1, 0);
continue;
}
memset(params, 0, sizeof(struct dtlsservernewparams));
params->sesscache = malloc(sizeof(struct sessioncacheentry));
if (!params->sesscache) {
free(params);
cacheexpire(sessioncache, &lastexpiry);
recv(s, buf, 1, 0);
continue;
}
memset(params->sesscache, 0, sizeof(struct sessioncacheentry));
pthread_mutex_init(&params->sesscache->mutex, NULL);
params->sesscache->rbios = newqueue();
if (hash_insert(sessioncache, &from, fromlen, params->sesscache)) {
params->sock = s;
memcpy(&params->addr, &from, fromlen);
if (udp2bio(s, params->sesscache->rbios, cnt)) {
debug(DBG_DBG, "udpdtlsserverrd: got DTLS in UDP from %s", addr2string((struct sockaddr *)&from, fromlen));
if (!pthread_create(&dtlsserverth, NULL, dtlsservernew, (void *)params)) {
pthread_detach(dtlsserverth);
cacheexpire(sessioncache, &lastexpiry);
continue;
}
debug(DBG_ERR, "udpdtlsserverrd: pthread_create failed");
}
hash_extract(sessioncache, &from, fromlen);
}
freebios(params->sesscache->rbios);
pthread_mutex_destroy(&params->sesscache->mutex);
free(params->sesscache);
free(params);
cacheexpire(sessioncache, &lastexpiry);
}
}
int dtlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
......@@ -451,7 +528,9 @@ int clientradputdtls(struct server *server, unsigned char *rad) {
size_t len;
unsigned long error;
struct clsrvconf *conf = server->conf;
if (!server->ssl)
return 0;
len = RADLEN(rad);
while ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
while ((error = ERR_get_error()))
......@@ -496,19 +575,22 @@ void *dtlsclientrd(void *arg) {
struct server *server = (struct server *)arg;
unsigned char *buf;
struct timeval lastconnecttry;
int secs;
for (;;) {
/* yes, lastconnecttry is really necessary */
lastconnecttry = server->lastconnecttry;
buf = raddtlsget(server->ssl, server->rbios);
for (secs = 0; !(buf = raddtlsget(server->ssl, server->rbios, 10)) && !server->lostrqs && secs < IDLE_TIMEOUT; secs += 10);
if (!buf) {
dtlsconnect(server, &lastconnecttry, 0, "dtlsclientrd");
continue;
}
if (!replyh(server, buf))
free(buf);
}
ERR_remove_state(0);
server->clientrdgone = 1;
return NULL;
}
void addserverextradtls(struct clsrvconf *conf) {
......
......@@ -8,9 +8,7 @@
void *udpdtlsserverrd(void *arg);
int dtlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
void *dtlsservernew(void *arg);
void *dtlsclientrd(void *arg);
void *udpdtlsclientrd(void *arg);
int clientradputdtls(struct server *server, unsigned char *rad);
void addserverextradtls(struct clsrvconf *conf);
void initextradtls();
......@@ -12,12 +12,6 @@
#include "list.h"
#include "hash.h"
struct entry {
void *key;
uint32_t keylen;
void *data;
};
/* allocates and initialises hash structure; returns NULL if malloc fails */
struct hash *hash_create() {
struct hash *h = malloc(sizeof(struct hash));
......@@ -39,8 +33,8 @@ void hash_destroy(struct hash *h) {
if (!h)
return;
for (ln = list_first(h->hashlist); ln; ln = list_next(ln)) {
free(((struct entry *)ln->data)->key);
free(((struct entry *)ln->data)->data);
free(((struct hash_entry *)ln->data)->key);
free(((struct hash_entry *)ln->data)->data);
}
list_destroy(h->hashlist);
pthread_mutex_destroy(&h->mutex);
......@@ -48,13 +42,14 @@ void hash_destroy(struct hash *h) {
/* insert entry in hash; returns 1 if ok, 0 if malloc fails */
int hash_insert(struct hash *h, void *key, uint32_t keylen, void *data) {
struct entry *e;
struct hash_entry *e;
if (!h)
return 0;
e = malloc(sizeof(struct entry));
e = malloc(sizeof(struct hash_entry));
if (!e)
return 0;
memset(e, 0, sizeof(struct hash_entry));
e->key = malloc(keylen);
if (!e->key) {
free(e);
......@@ -77,13 +72,13 @@ int hash_insert(struct hash *h, void *key, uint32_t keylen, void *data) {
/* reads entry from hash */
void *hash_read(struct hash *h, void *key, uint32_t keylen) {
struct list_node *ln;
struct entry *e;
struct hash_entry *e;
if (!h)
return 0;
pthread_mutex_lock(&h->mutex);
for (ln = list_first(h->hashlist); ln; ln = list_next(ln)) {
e = (struct entry *)ln->data;
e = (struct hash_entry *)ln->data;
if (e->keylen == keylen && !memcmp(e->key, key, keylen)) {
pthread_mutex_unlock(&h->mutex);
return e->data;
......@@ -96,13 +91,13 @@ void *hash_read(struct hash *h, void *key, uint32_t keylen) {
/* extracts entry from hash */
void *hash_extract(struct hash *h, void *key, uint32_t keylen) {
struct list_node *ln;
struct entry *e;
struct hash_entry *e;
if (!h)
return 0;
pthread_mutex_lock(&h->mutex);
for (ln = list_first(h->hashlist); ln; ln = list_next(ln)) {
e = (struct entry *)ln->data;
e = (struct hash_entry *)ln->data;
if (e->keylen == keylen && !memcmp(e->key, key, keylen)) {
free(e->key);
list_removedata(h->hashlist, e);
......@@ -114,3 +109,24 @@ void *hash_extract(struct hash *h, void *key, uint32_t keylen) {
pthread_mutex_unlock(&h->mutex);
return NULL;
}
/* returns first entry */
struct hash_entry *hash_first(struct hash *hash) {
struct list_node *ln;
struct hash_entry *e;
if (!hash || !