#include <SPAD/LIBC.H>
#include <SPAD/DEV.H>
#include <SPAD/SYNC.H>
#include <SPAD/SYSLOG.H>

#include <OPENSSL/SSL.H>
#include <OPENSSL/BIO.H>
#include <OPENSSL/X509V3.H>

#include "EAP.H"

#include "ROOTER.H"

#include "EAPTLS.H"

/* Per-instance state of an EAP-TLS/TTLS client conversation, stored in inst->private. */
struct tls_private {
	__u8 version;		/* EAP-TLS version bits taken from the server's Start packet; eap_tls() only accepts 0 */
	unsigned fragment_size;	/* maximum number of TLS octets placed into one outgoing EAP packet */
	SSL_CTX *ctx;		/* context created by new_ssl_ctx() in initialize_tls() */
	SSL *ssl;		/* client connection; bio_in/bio_out are attached to it with SSL_set_bio() */
	STACK_OF(X509) *cert_chain;	/* peer chain from SSL_get_peer_cert_chain(); cleanup_tls() does not free it */
	BIO *bio_in;		/* memory BIO into which TLS data received in EAP packets is written */
	BIO *bio_out;		/* memory BIO from which TLS data to be sent is drained */
	__u8 *raw_data_out;	/* TLS octets drained from bio_out, waiting for tls_send_packet() */
	unsigned raw_data_out_len;
	__u8 *data_out;		/* plaintext AVPs built by ttls_avp(), waiting for SSL_write() in tls_flush_data() */
	unsigned data_out_len;
};

/* Access the tls_private state attached to an instance. */
#define tls(inst)	((struct tls_private *)(inst)->private)

/* Forward declarations of the file-local helpers; eap_tls() below is the only external entry point in this file. */
static int initialize_tls(struct instance *inst);
static void cleanup_tls(struct instance *inst);
static int ttls_get_config_with_certificate(struct instance *inst, struct config *c);
static int ttls_get_config_without_certificate(struct instance *inst, struct config *c);
static int tls_flush_data(struct instance *inst);
static int tls_prepare_packet(struct instance *inst);
static int tls_send_packet(struct instance *inst);
static int tls_send_ack(struct instance *inst);
static int tls_parse_packet(struct instance *inst, EAP_PACKET_REQUEST *packet, __u8 **data, unsigned *data_size, unsigned *total_size);
static int tls_read_packet(struct instance *inst);
static void ttls_avp(struct instance *inst, __u32 code, __u32 vendor, __u8 flags, __u8 *data, unsigned len);

int eap_tls(struct instance *inst)
{
	int return_value = AUTH_PROTOCOL_VIOLATION;
	int res;
	struct config *cfg;
	inst->private = xmalloc(sizeof(struct tls_private));
	memset(inst->private, 0, sizeof(struct tls_private));
	tls(inst)->fragment_size = TLS_FRAGMENT_SIZE;	/* when modifying, make sure that overflow does not happen wrt. EAP_MAX_DATA_LENGTH */
	if (__unlikely(get_packet(inst, &inst->packet))) goto ret;
	if (__unlikely(!inst->packet.data_length)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: NO DATA IN INITIAL PACKET");
		goto ret;
	}
	if (__unlikely(!(inst->packet.data[0] & EAP_TLS_S))) {
		if (errorlevel >= 2)
			KERNEL$SYSLOG(__SYSLOG_NET_WARNING, inst->name, "TLS: NOT A START PACKET, FLAGS %02X", inst->packet.data[0]);
		goto ret;
	}
	tls(inst)->version = inst->packet.data[0] & EAP_TLS_VERSION;
	if (__unlikely(tls(inst)->version != 0)) {
		if (errorlevel >= 2)
			KERNEL$SYSLOG(__SYSLOG_NET_WARNING, inst->name, "TLS: UNSUPPORTED TLS VERSION %02X", tls(inst)->version);
		goto ret;
	}
	if (__unlikely(initialize_tls(inst))) goto ret;

	try_connect_again:
	lock_ssl();
	res = SSL_connect(tls(inst)->ssl);
	if (res != 1) {
		int err = SSL_get_error(tls(inst)->ssl, res);
		if (err == SSL_ERROR_WANT_READ) {
			unlock_ssl();
			debug(("want read: prepare\n"));
			if (__unlikely(tls_prepare_packet(inst))) goto ret;
			if (__unlikely(tls_send_packet(inst))) goto ret;
			if (__unlikely(tls_read_packet(inst))) goto ret;
			goto try_connect_again;
		} else if (err == SSL_ERROR_WANT_WRITE) {
			KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: CAN'T WRITE TO MEMORY BIO");
			print_ssl_errors();
			unlock_ssl();
			goto ret;
		} else {
			if (errorlevel >= 1) {
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: SSL ERROR %d", err);
				print_ssl_errors();
			}
			clear_ssl_errors();
			unlock_ssl();
			goto ret;
		}
	}
	unlock_ssl();
	if (__unlikely(tls_prepare_packet(inst))) goto ret;
	/* do not send it yet because some new encrypted data might need to be added */
	/*if (__unlikely(tls_send_packet(inst))) goto ret;*/

	lock_ssl();
	tls(inst)->cert_chain = SSL_get_peer_cert_chain(tls(inst)->ssl);
	debug(("cert chain: %p\n", tls(inst)->cert_chain));
	cfg = get_config(inst, ttls_get_config_with_certificate);
	if (!cfg) cfg = get_config(inst, ttls_get_config_without_certificate);
	if (__unlikely(!cfg)) {
		unlock_ssl();
		if (errorlevel >= 2)
			KERNEL$SYSLOG(__SYSLOG_NET_WARNING, inst->name, "TLS: NO CONFIG ENTRY FOR CONNECTION");
		goto ret;
	}
	if (!cfg->inner_algorithm) {
		if (!cfg->hash) goto pap;
		else goto chap;
	}
	if (!_strcasecmp(cfg->inner_algorithm, "pap")) {
		__u8 *pwd;
		unsigned pwdlen, pwdlenpadded;
		pap:
		if (__unlikely(!cfg->user) || __unlikely(!cfg->password)) {
			KERNEL$SYSLOG(__SYSLOG_CONF_ERROR, inst->name, "TLS: USERNAME OR PASSWORD NOT SPECIFIED FOR PAP");
			goto ret;
		}
		ttls_avp(inst, AVP_USER, 0, AVP_FLAG_MANDATORY, (__u8 *)cfg->user, strlen(cfg->user));
		pwdlen = strlen(cfg->password);
		pwdlenpadded = (pwdlen + 15) & ~15;
		pwd = xmalloc(pwdlenpadded);
		memcpy(pwd, cfg->password, pwdlen);
		memset(pwd + pwdlen, 0, pwdlenpadded - pwdlen);
		ttls_avp(inst, AVP_PASSWORD, 0, AVP_FLAG_MANDATORY, pwd, pwdlenpadded);
		free(pwd);
		debug(("sending pap: \"%s\", \"%s\"\n", cfg->user, cfg->password));
		unlock_ssl();
		if (__unlikely(tls_flush_data(inst))) goto ret;
		debug(("sent pap\n"));
		return_value = AUTH_DONT_KNOW;
		goto ret;
	} else if (!_strcasecmp(cfg->inner_algorithm, "chap")) {
		chap:
		unlock_ssl();
		KERNEL$SYSLOG(__SYSLOG_SW_INCOMPATIBILITY, inst->name, "TLS: CHAP NOT YET SUPPORTED");
		goto ret;
	} else {
		KERNEL$SUICIDE("eap_ttls: UNKNOWN INNER ALGORITHM %s", cfg->inner_algorithm);
	}
	unlock_ssl();

	ret:
	cleanup_tls(inst);
	free(inst->private);
	return return_value;
}

static int initialize_tls(struct instance *inst)
{
	lock_ssl();
	new_ssl_ctx(&tls(inst)->ctx);
	if (__unlikely(!tls(inst)->ctx)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: CAN'T CREATE SSL CONTEXT");
		print_ssl_errors();
		unlock_ssl();
		return -1;
	}
	tls(inst)->ssl = SSL_new(tls(inst)->ctx);
	if (__unlikely(!tls(inst)->ssl)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: CAN'T CREATE SSL CONNECTION");
		print_ssl_errors();
		unlock_ssl();
		return -1;
	}
	SSL_set_app_data(tls(inst)->ssl, inst);
	SSL_set_options(tls(inst)->ssl, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_SINGLE_DH_USE);
	tls(inst)->bio_in = BIO_new(BIO_s_mem());
	if (__unlikely(!tls(inst)->bio_in)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: CAN'T CREATE INPUT BIO");
		print_ssl_errors();
		unlock_ssl();
		return -1;
	}
	tls(inst)->bio_out = BIO_new(BIO_s_mem());
	if (__unlikely(!tls(inst)->bio_out)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: CAN'T CREATE OUTPUT BIO");
		print_ssl_errors();
		BIO_free(tls(inst)->bio_in);
		unlock_ssl();
		return -1;
	}
	SSL_set_bio(tls(inst)->ssl, tls(inst)->bio_in, tls(inst)->bio_out);
	unlock_ssl();
	return 0;
}

static void cleanup_tls(struct instance *inst)
{
	lock_ssl();
	if (tls(inst)->ssl) SSL_free(tls(inst)->ssl);
	if (tls(inst)->ctx) SSL_CTX_free(tls(inst)->ctx);
	free(tls(inst)->raw_data_out);
	free(tls(inst)->data_out);
	unlock_ssl();
}

/* Return 1 if name (compared case-insensitively) is an inner method eap_tls() implements; only "pap" is. */
static int is_supported_inner_algorithm(char *name)
{
	return _strcasecmp(name, "pap") == 0;
}

static int ttls_get_config_common(struct instance *inst, struct config *c)
{
	if (c->inner_algorithm && !is_supported_inner_algorithm(c->inner_algorithm)) return 0;
	if (!c->user || !c->password) return 0;
	if (c->cncheck) {
		char name[__MAX_STR_LEN];
		X509 *cert = SSL_get_peer_certificate(tls(inst)->ssl);
		if (__unlikely(!cert)) return 0;
		if (__unlikely(X509_NAME_get_text_by_NID(X509_get_subject_name(cert), NID_commonName, name, sizeof name) < 0)) {
			X509_free(cert);
			return 0;
		}
		X509_free(cert);
		debug(("cert name: \"%s\"\n", name));
		if (c->cnexact) {
			debug(("exact(%s)\n", c->cncheck));
			if (__unlikely(strcmp(name, c->cncheck))) return 0;
		} else {
			int d = strlen(name) - strlen(c->cncheck);
			debug(("inexact(%s)\n", c->cncheck));
			if (d < 0) return 0;
			if (__unlikely(strcmp(name + d, c->cncheck))) return 0;
			if (d > 0) if (name[d - 1] != '.') return 0;
		}
	}
	return 1;
}

static int ttls_get_config_with_certificate(struct instance *inst, struct config *c)
{
	int val;
	X509_STORE_CTX *tmp_ctx;
	if (__unlikely(!ttls_get_config_common(inst, c))) return 0;
	if (__unlikely(!c->cert_store)) return 0;
	if (__unlikely(!tls(inst)->cert_chain) || __unlikely(!sk_X509_num(tls(inst)->cert_chain))) return 0;
	/*if (__unlikely(!X509_NAME_oneline(X509_get_subject_name(sk_X509_value(tls(inst)->cert_chain, 0)), buf, sizeof(buf))))*/
	tmp_ctx = X509_STORE_CTX_new();
	if (__unlikely(!tmp_ctx)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "CAN'T GET X509 STORE CONTEXT");
		print_ssl_errors();
		return 0;
	}
	if (__unlikely(!X509_STORE_CTX_init(tmp_ctx, c->cert_store, sk_X509_value(tls(inst)->cert_chain, 0), tls(inst)->cert_chain))) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "CAN'T INITIALIZE X509 STORE CONTEXT");
		print_ssl_errors();
		val = 0;
		goto ret;
	}
	X509_STORE_CTX_set_purpose(tmp_ctx, X509_PURPOSE_SSL_SERVER);
	if (X509_verify_cert(tmp_ctx) == 1) val = 1;
	else val = 0;
	debug(("cert verify: %d\n", val));
	ret:
	X509_STORE_CTX_free(tmp_ctx);
	debug(("return: %d\n", val));
	return val;
}

/* get_config() fallback callback: entries that pass the common checks and carry no cert_store. */
static int ttls_get_config_without_certificate(struct instance *inst, struct config *c)
{
	if (!ttls_get_config_common(inst, c))
		return 0;
	return c->cert_store == NULL;
}

static int tls_prepare_packet(struct instance *inst)
{
	unsigned len;
	lock_ssl();
	len = BIO_ctrl_pending(tls(inst)->bio_out);
	tls(inst)->raw_data_out = xrealloc(tls(inst)->raw_data_out, tls(inst)->raw_data_out_len + len);
	if (__unlikely(len)) {
		int r = BIO_read(tls(inst)->bio_out, tls(inst)->raw_data_out + tls(inst)->raw_data_out_len, len);
		if (__unlikely(r != len)) {
			KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: BIO READ RETURNED %d", r);
			print_ssl_errors();
			unlock_ssl();
			return -EIO;
		}
	}
	tls(inst)->raw_data_out_len += len;
	unlock_ssl();
	return 0;
}

static int tls_send_packet(struct instance *inst)
{
	unsigned off = 0;
	int need_len = tls(inst)->fragment_size < tls(inst)->raw_data_out_len;
	while (1) {
		__u8 *data;
		unsigned data_size, total_size;
		int r;
		unsigned x;
		unsigned size_to_send = off + tls(inst)->fragment_size < tls(inst)->raw_data_out_len ? tls(inst)->fragment_size : tls(inst)->raw_data_out_len - off;
		inst->packet.data[0] = tls(inst)->version;
		if (off + size_to_send < tls(inst)->raw_data_out_len) inst->packet.data[0] |= EAP_TLS_M;
		x = 1;
		if (need_len) {
			inst->packet.data[0] |= EAP_TLS_L;
			inst->packet.data[x++] = tls(inst)->raw_data_out_len >> 24;
			inst->packet.data[x++] = tls(inst)->raw_data_out_len >> 16;
			inst->packet.data[x++] = tls(inst)->raw_data_out_len >> 8;
			inst->packet.data[x++] = tls(inst)->raw_data_out_len;
		}
		if (__unlikely(x + size_to_send > EAP_MAX_DATA_LENGTH))
			KERNEL$SUICIDE("tls_send_packet: PACKET OVERFLOW: %u + %u > %u", x, size_to_send, EAP_MAX_DATA_LENGTH);
		memcpy(&inst->packet.data[x], tls(inst)->raw_data_out + off, size_to_send);
		inst->packet.data_length = x + size_to_send;
		memset(&inst->packet.data[inst->packet.data_length], 0, sizeof inst->packet.data - inst->packet.data_length);
		if (__unlikely(r = send_packet(inst, &inst->packet))) return r;
		off += size_to_send;
		if (off >= tls(inst)->raw_data_out_len) break;
		if (__unlikely(r = get_packet(inst, &inst->packet))) return r;
		if (__unlikely(r = tls_parse_packet(inst, &inst->packet, &data, &data_size, &total_size))) return r;
		if (__unlikely(data_size != 0) || __unlikely(total_size != 0)) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: INVALID ACK PACKET, FLAGS %02X, DATA SIZE %u, TOTAL SIZE %u", inst->packet.data[0], data_size, total_size);
			return -EPROTO;
		}
	}
	tls(inst)->raw_data_out_len = 0;
	return 0;
}

static int tls_parse_packet(struct instance *inst, EAP_PACKET_REQUEST *packet, __u8 **data, unsigned *data_size, unsigned *total_size)
{
	unsigned ds;
	unsigned x;
	unsigned ts = 0;	/* supress warning */
	if (__unlikely(!packet->data_length)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: ZERO-SIZE PACKET");
		return -EPROTO;
	}
	if (__unlikely((packet->data[0] & EAP_TLS_VERSION) != tls(inst)->version)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: VERSION MISMATCH, RECEIVED %02X, WANTED %02X", packet->data[0] & EAP_TLS_VERSION, tls(inst)->version);
		return -EPROTO;
	}
	if (__unlikely(packet->data[0] & EAP_TLS_S)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: START PACKET, FLAGS %02X", packet->data[0]);
		return -EPROTO;
	}
	x = 1;
	if (packet->data[0] & EAP_TLS_L) {
		if (__unlikely(x + 4 > packet->data_length)) {
			too_short:
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: TOO SHORT PACKET, FLAGS %02X, LENGTH %u", packet->data[0], packet->data_length);
			return -EPROTO;
		}
		if (__unlikely((ts = (packet->data[x] << 24) | (packet->data[x + 1] << 16) | (packet->data[x + 2] << 8) | packet->data[x + 3]) > TLS_MAX_TOTAL_SIZE)) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: MAXIMUM TOTAL SIZE EXCEEDED: %u > %u", ts, TLS_MAX_TOTAL_SIZE);
			return -EPROTO;
		}
		x += 4;
	}
	if (__unlikely(packet->data[0] & EAP_TLS_T)) {
		if (__unlikely(x + 4 > packet->data_length)) goto too_short;
		ds = (packet->data[x] << 24) | (packet->data[x + 1] << 16) | (packet->data[x + 2] << 8) | packet->data[x + 3];
		x += 4;
	} else {
		ds = packet->data_length - x;
	}
	if (__unlikely(ds > packet->data_length - x)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: MESSAGE LENGTH FIELD LARGER THAN PACKET DATA LENGTH: %u > %u", ds, packet->data_length - x);
		return -EPROTO;
	}
	if (!(packet->data[0] & EAP_TLS_L)) {
		ts = ds;
	}
	if (__unlikely(ts < ds)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: TOTAL SIZE LESS THAN FRAGMENT SIZE: %u < %u", ts, ds);
		return -EPROTO;
	}
	*data = &packet->data[x];
	*data_size = ds;
	*total_size = ts;
	x += ds;
	if (__unlikely(x < packet->data_length)) {
		/* outer TLVs ... ignored */
	}
	return 0;
}

/*
 * Receive one complete TLS message from the server, reassembling fragments,
 * and feed it into bio_in. Between fragments an ACK is sent. Each further
 * fragment must be non-empty, must not exceed the remaining length, and
 * must repeat the same total if it carries the L bit. The M bit must be set
 * exactly while data remains. Returns 0, -EIO on a short BIO write, -EPROTO
 * on protocol violations, or the error from the packet helpers.
 */
static int tls_read_packet(struct instance *inst)
{
	__u8 *frag;
	unsigned frag_len, msg_len, left;
	int r;

	r = get_packet(inst, &inst->packet);
	if (__unlikely(r)) return r;
	r = tls_parse_packet(inst, &inst->packet, &frag, &frag_len, &msg_len);
	if (__unlikely(r)) return r;
	left = msg_len - frag_len;

	for (;;) {
		unsigned this_total;

		if (frag_len) {
			lock_ssl();
			r = BIO_write(tls(inst)->bio_in, frag, frag_len);
			if (__unlikely(r != frag_len)) {
				KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: BIO WRITE RETURNED %d", r);
				print_ssl_errors();
				unlock_ssl();
				return -EIO;
			}
			unlock_ssl();
		}
		if (!left)
			break;

		if (__unlikely(!(inst->packet.data[0] & EAP_TLS_M))) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: PACKET DOES NOT HAVE 'M' BIT SET BUT MORE PACKETS SHOULD FOLLOW");
			return -EPROTO;
		}

		r = tls_send_ack(inst);
		if (__unlikely(r)) return r;

		r = get_packet(inst, &inst->packet);
		if (__unlikely(r)) return r;
		r = tls_parse_packet(inst, &inst->packet, &frag, &frag_len, &this_total);
		if (__unlikely(r)) return r;
		if ((inst->packet.data[0] & EAP_TLS_L) && __unlikely(this_total != msg_len)) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: PACKET HAS DIFFERENT TOTAL SIZE THAN PREVIOUS: %u != %u", this_total, msg_len);
			return -EPROTO;
		}
		if (__unlikely(!frag_len)) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: ZERO-SIZED FRAGMENT PACKET");
			return -EPROTO;
		}
		if (__unlikely(frag_len > left)) {
			if (errorlevel >= 1)
				KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: FRAGMENT EXCEEDS ORIGINAL MESSAGE SIZE, SIZE %u, POSITION %u, TOTAL SIZE %u", frag_len, msg_len - left, msg_len);
			return -EPROTO;
		}
		left -= frag_len;
	}

	if (__unlikely(inst->packet.data[0] & EAP_TLS_M)) {
		if (errorlevel >= 1)
			KERNEL$SYSLOG(__SYSLOG_NET_ERROR, inst->name, "TLS: 'M' BIT SET IN LAST FRAGMENT");
		return -EPROTO;
	}
	return 0;
}

/* Send an EAP-TLS acknowledgement: a packet whose only content is the flags octet holding the version bits. */
static int tls_send_ack(struct instance *inst)
{
	memset(inst->packet.data, 0, sizeof inst->packet.data);
	inst->packet.data[0] = tls(inst)->version;
	inst->packet.data_length = 1;
	return send_packet(inst, &inst->packet);
}

static int tls_flush_data(struct instance *inst)
{
	int r;
	if (__unlikely(!tls(inst)->data_out_len)) {
		if (!tls(inst)->raw_data_out_len) return tls_send_ack(inst);
		return tls_send_packet(inst);
	}
	lock_ssl();
	r = SSL_write(tls(inst)->ssl, tls(inst)->data_out, tls(inst)->data_out_len);
	if (__unlikely(r != tls(inst)->data_out_len)) {
		KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: SSL WRITE RETURNED %d", r);
		print_ssl_errors();
		unlock_ssl();
		return -EIO;
	}
	tls(inst)->data_out_len = 0;
	unlock_ssl();
	r = tls_prepare_packet(inst);
	if (__unlikely(r)) return r;
	return tls_send_packet(inst);
}

/*
 * Append one AVP to the plaintext queue data_out.
 * Layout: 4-octet code, 1-octet flags, 3-octet length (header + data,
 * excluding padding), 4-octet Vendor-ID when the vendor flag is set, the
 * data, then zero padding up to a multiple of 4 octets.
 * A non-zero vendor forces the vendor flag on. A vendor flag passed in by
 * the caller also emits the Vendor-ID field, and the field is counted in
 * the length, so the buffer is always sized for every octet written.
 * An AVP whose length does not fit in 24 bits is logged and not queued.
 */
static void ttls_avp(struct instance *inst, __u32 code, __u32 vendor, __u8 flags, __u8 *data, unsigned len)
{
	unsigned total_len;
	unsigned padded_pos;
	unsigned pos = tls(inst)->data_out_len;
	if (__unlikely(vendor != 0)) flags |= AVP_FLAG_VENDOR;
	/* bound len before adding the header so that 8 + len cannot wrap */
	if (__unlikely(len >= 1 << 24)) goto too_long;
	total_len = 8 + len;
	/* the Vendor-ID field is present exactly when the flag is set; count it */
	if (__unlikely(flags & AVP_FLAG_VENDOR)) total_len += 4;
	if (__unlikely(total_len >= 1 << 24)) goto too_long;
	padded_pos = pos + ((total_len + 3) & ~3);
	tls(inst)->data_out = xrealloc(tls(inst)->data_out, padded_pos);
	tls(inst)->data_out[pos++] = code >> 24;
	tls(inst)->data_out[pos++] = code >> 16;
	tls(inst)->data_out[pos++] = code >> 8;
	tls(inst)->data_out[pos++] = code;
	tls(inst)->data_out[pos++] = flags;
	tls(inst)->data_out[pos++] = total_len >> 16;
	tls(inst)->data_out[pos++] = total_len >> 8;
	tls(inst)->data_out[pos++] = total_len;
	if (__unlikely(flags & AVP_FLAG_VENDOR)) {
		tls(inst)->data_out[pos++] = vendor >> 24;
		tls(inst)->data_out[pos++] = vendor >> 16;
		tls(inst)->data_out[pos++] = vendor >> 8;
		tls(inst)->data_out[pos++] = vendor;
	}
	memcpy(&tls(inst)->data_out[pos], data, len);
	pos += len;
	while (pos < padded_pos) tls(inst)->data_out[pos++] = 0;
	tls(inst)->data_out_len = padded_pos;
	return;

	too_long:
	KERNEL$SYSLOG(__SYSLOG_SW_ERROR, inst->name, "TLS: AVP %u TOO LONG, DATA LENGTH %u", (unsigned)code, len);
}
