#include <STDIO.H>
#include <ERRNO.H>
#include <STRING.H>
#include <SPAD/SYNC.H>
#include <OPENSSL/SHA.H>
#include <OPENSSL/PEM.H>
#include <OPENSSL/DSA.H>
#include <OPENSSL/BN.H>
#include <OPENSSL/ERR.H>

#include "SSHD.H"
#include "OPS.H"

/* Smallest RSA modulus, in bits (compared against BN_num_bits(n)), that rsa_verify() accepts. */
#define SSH_RSA_MINIMUM_MODULUS_SIZE	768

/*
 * Opaque key handle passed through the struct public_type callbacks.
 *
 * Handles are produced by casting the pointer returned from DSA_new(),
 * RSA_new() or EVP_PKEY_get1_DSA()/EVP_PKEY_get1_RSA() to union p_key *.
 * The underlying allocation is therefore a bare DSA or RSA object, not a
 * full union: each method must only touch the member that matches its own
 * algorithm (dss_* use .dsa, rsa_* use .rsa).
 */
union p_key {
	DSA dsa;
	RSA rsa;
};

/* Forward declarations so the ssh_dss / ssh_rsa method tables below can reference the implementations. */
static union p_key *dss_load_private_from_file(const char *filename, char *err);
static union p_key *dss_load_public_from_blob(const __u8 *blob, unsigned len);
static int dss_put_public_to_blob(union p_key *key, __u8 **blob, unsigned *len);
static int dss_sign(union p_key *key, __u8 **signature, unsigned *signaturelen, const __u8 *data, unsigned datalen);
static int dss_verify(union p_key *key, const __u8 *signature, unsigned signaturelen, const __u8 *data, unsigned datalen);
static void dss_destroy(union p_key *key);

static union p_key *rsa_load_private_from_file(const char *filename, char *err);
static union p_key *rsa_load_public_from_blob(const __u8 *blob, unsigned len);
static int rsa_put_public_to_blob(union p_key *key, __u8 **blob, unsigned *len);
static int rsa_sign(union p_key *key, __u8 **signature, unsigned *signaturelen, const __u8 *data, unsigned datalen);
static int rsa_verify(union p_key *key, const __u8 *signature, unsigned signaturelen, const __u8 *data, unsigned datalen);
static void rsa_destroy(union p_key *key);

/*
 * Method table for "ssh-dss" (DSA) keys.
 * Initialised positionally: the order is assumed to match the member order
 * of struct public_type (declared outside this file) -- keep them in sync.
 */
static const struct public_type ssh_dss = {
	"ssh-dss",
	dss_load_private_from_file,
	dss_load_public_from_blob,
	dss_put_public_to_blob,
	dss_sign,
	dss_verify,
	dss_destroy,
};

/*
 * Method table for "ssh-rsa" keys.
 * Initialised positionally: the order is assumed to match the member order
 * of struct public_type (declared outside this file) -- keep them in sync.
 */
static const struct public_type ssh_rsa = {
	"ssh-rsa",
	rsa_load_private_from_file,
	rsa_load_public_from_blob,
	rsa_put_public_to_blob,
	rsa_sign,
	rsa_verify,
	rsa_destroy,
};

/*
 * NULL-terminated list of supported public key algorithms, searched in
 * this order by get_public_type().
 */
const struct public_type * const public_types[] = {
	&ssh_rsa,
	&ssh_dss,
	NULL
};

/* Generic functions */

const struct public_type *get_public_type(const char *name)
{
	int i;
	for (i = 0; public_types[i]; i++)
		if (!strcmp(public_types[i]->name, name))
			return public_types[i];
	return NULL;
}

/*
 * Read a length-prefixed string from STR at offset *PTR (advancing *PTR via
 * get_len_str_from_str()) and compare it with the C string WANT.
 *
 * Returns -1 if the string cannot be read, 1 if it differs from WANT
 * (length or content), and 0 on an exact match.
 */
static int check_name_from_str(const __u8 *str, unsigned len, unsigned *ptr, const char *want)
{
	const __u8 *got;
	unsigned gotlen;
	size_t wantlen;

	if (__unlikely(get_len_str_from_str(str, len, ptr, &got, &gotlen)))
		return -1;

	wantlen = strlen(want);
	if (__unlikely(gotlen != wantlen))
		return 1;
	if (__unlikely(memcmp(got, want, wantlen) != 0))
		return 1;
	return 0;
}

/*
 * Load a PEM-encoded private key from FILENAME and return its DSA or RSA
 * payload as a union p_key handle.
 *
 * The key is read with an empty passphrase.  A key whose type differs from
 * TYPE is rejected; TYPE itself is expected to be EVP_PKEY_DSA or
 * EVP_PKEY_RSA (any other value ends in KERNEL$SUICIDE once a matching key
 * has been read).
 *
 * On failure NULL is returned.  If the file cannot be opened, ERR is set to
 * the empty string and errno is left untouched by this function; otherwise
 * a message is written to ERR with _snprintf(err, __MAX_STR_LEN, ...) and
 * errno is set to EINVAL.
 */
static union p_key *key_load_private_from_file(const char *filename, int type, char *err)
{
	FILE *f;
	EVP_PKEY *pkey;
	void *result = NULL;

	f = fopen(filename, "r");
	if (__unlikely(!f)) {
		*err = 0;
		return NULL;
	}

	pkey = PEM_read_PrivateKey(f, NULL, NULL, "");
	if (__unlikely(!pkey)) {
		_snprintf(err, __MAX_STR_LEN, "INVALID KEY FORMAT: %lX", ERR_get_error());
		goto close_einval;
	}
	if (__unlikely(pkey->type != type)) {
		_snprintf(err, __MAX_STR_LEN, "INVALID KEY TYPE");
		goto free_einval;
	}

	/* get1 variants hand back their own reference, so pkey can be freed below. */
	switch (pkey->type) {
		case EVP_PKEY_DSA:
			result = EVP_PKEY_get1_DSA(pkey);
			break;
		case EVP_PKEY_RSA:
			result = EVP_PKEY_get1_RSA(pkey);
			break;
		default:
			KERNEL$SUICIDE("INVALID TYPE %d REQUESTED", pkey->type);
			break;
	}
	if (__unlikely(!result)) {
		_snprintf(err, __MAX_STR_LEN, "COULD NOT GET KEY DATA");
		goto free_einval;
	}

	EVP_PKEY_free(pkey);
	fclose(f);
	return result;

	free_einval:
	EVP_PKEY_free(pkey);
	close_einval:
	fclose(f);
	errno = EINVAL;
	return NULL;
}

/* DSA */

/* Load a PEM DSA private key; see key_load_private_from_file() for error reporting. */
static union p_key *dss_load_private_from_file(const char *filename, char *err)
{
	union p_key *key = key_load_private_from_file(filename, EVP_PKEY_DSA, err);

	return key;
}

/*
 * Parse an "ssh-dss" public key blob into a freshly allocated DSA key.
 *
 * The blob must start with the name string (checked against ssh_dss.name),
 * followed by p, q, g and pub_key, in that order, each read with
 * get_bn_from_str().  Bytes after pub_key are not checked.
 *
 * Returns the key on success; on failure returns NULL with errno set to
 * ENOMEM (allocation failure) or EINVAL (malformed blob).
 */
static union p_key *dss_load_public_from_blob(const __u8 *blob, unsigned len)
{
	unsigned pos = 0;
	union p_key *key = (union p_key *)DSA_new();

	if (__unlikely(!key)) {
		errno = ENOMEM;
		return NULL;
	}

	key->dsa.p = BN_new();
	if (__unlikely(!key->dsa.p)) goto destroy_enomem;
	key->dsa.q = BN_new();
	if (__unlikely(!key->dsa.q)) goto destroy_enomem;
	key->dsa.g = BN_new();
	if (__unlikely(!key->dsa.g)) goto destroy_enomem;
	key->dsa.pub_key = BN_new();
	if (__unlikely(!key->dsa.pub_key)) goto destroy_enomem;

	if (__unlikely(check_name_from_str(blob, len, &pos, ssh_dss.name))) goto destroy_einval;
	if (__unlikely(get_bn_from_str(blob, len, &pos, key->dsa.p))) goto destroy_einval;
	if (__unlikely(get_bn_from_str(blob, len, &pos, key->dsa.q))) goto destroy_einval;
	if (__unlikely(get_bn_from_str(blob, len, &pos, key->dsa.g))) goto destroy_einval;
	if (__unlikely(get_bn_from_str(blob, len, &pos, key->dsa.pub_key))) goto destroy_einval;

	return key;

	destroy_enomem:
	errno = ENOMEM;
	dss_destroy(key);
	return NULL;

	destroy_einval:
	errno = EINVAL;
	dss_destroy(key);
	return NULL;
}

/*
 * Serialise a DSA public key as an "ssh-dss" blob into *BLOB / *LEN
 * (initialised with init_str()): the name string, then p, q, g and pub_key.
 * Returns 0 on success or -1 as soon as any append fails.
 */
static int dss_put_public_to_blob(union p_key *key, __u8 **blob, unsigned *len)
{
	init_str(blob, len);
	if (__unlikely(add_len_str_to_str(blob, len, (__u8 *)ssh_dss.name, strlen(ssh_dss.name))) ||
	    __unlikely(add_bn_to_str(blob, len, key->dsa.p)) ||
	    __unlikely(add_bn_to_str(blob, len, key->dsa.q)) ||
	    __unlikely(add_bn_to_str(blob, len, key->dsa.g)) ||
	    __unlikely(add_bn_to_str(blob, len, key->dsa.pub_key)))
		return -1;
	return 0;
}

/*
 * DSA signature blob layout used by dss_sign()/dss_verify():
 * INTBLOB_LEN bytes of r followed by INTBLOB_LEN bytes of s, each
 * big-endian and left-padded with zeros.
 */
#define INTBLOB_LEN	20
#define SIGBLOB_LEN	(INTBLOB_LEN * 2)

/*
 * Produce an "ssh-dss" signature over DATA.
 *
 * DATA is hashed with SHA-1 and signed with DSA_do_sign().  r and s are each
 * right-aligned into an INTBLOB_LEN-byte field (zero padded on the left), and
 * the resulting SIGBLOB_LEN-byte blob is appended after the name string to
 * *SIGNATURE / *SIGNATURELEN, which are (re)initialised with init_str().
 *
 * Returns 0 on success, -EINVAL if signing fails or r/s do not fit in
 * INTBLOB_LEN bytes, and -ENOMEM if appending to the output fails.
 */
static int dss_sign(union p_key *key, __u8 **signature, unsigned *signaturelen, const __u8 *data, unsigned datalen)
{
	__u8 digest[SHA_DIGEST_LENGTH], sigblob[SIGBLOB_LEN];
	DSA_SIG *sig;
	unsigned rlen, slen;

	SHA1(data, datalen, digest);
	sig = DSA_do_sign(digest, sizeof digest, &key->dsa);
	/* Report signing failure as a negative errno, like rsa_sign() does,
	   instead of a bare -1. */
	if (__unlikely(!sig)) return -EINVAL;

	rlen = BN_num_bytes(sig->r);
	slen = BN_num_bytes(sig->s);
	if (__unlikely(rlen > INTBLOB_LEN) || __unlikely(slen > INTBLOB_LEN))
		goto free_ret_einval;

	/* Right-align r in the first half and s in the second half. */
	memset(sigblob, 0, sizeof sigblob);
	if (__unlikely(!BN_bn2bin(sig->r, sigblob + SIGBLOB_LEN - INTBLOB_LEN - rlen))) goto free_ret_einval;
	if (__unlikely(!BN_bn2bin(sig->s, sigblob + SIGBLOB_LEN - slen))) goto free_ret_einval;
	DSA_SIG_free(sig);

		/* !!! SSH_BUG_SIGBLOB */

	init_str(signature, signaturelen);
	if (__unlikely(add_len_str_to_str(signature, signaturelen, (__u8 *)ssh_dss.name, strlen(ssh_dss.name)))) return -ENOMEM;
	if (__unlikely(add_len_str_to_str(signature, signaturelen, sigblob, SIGBLOB_LEN))) return -ENOMEM;
	return 0;

	free_ret_einval:
	DSA_SIG_free(sig);
	return -EINVAL;
}

/*
 * Verify an "ssh-dss" signature over DATA with the DSA public key KEY.
 *
 * SIGNATURE must consist of exactly the name string followed by one string
 * of SIGBLOB_LEN bytes (r then s, INTBLOB_LEN bytes each), with nothing
 * trailing.
 *
 * Returns 0 if the signature is valid, 1 if it does not match, -EINVAL for
 * malformed input or a DSA_do_verify() error, and -ENOMEM on allocation failure.
 */
static int dss_verify(union p_key *key, const __u8 *signature, unsigned signaturelen, const __u8 *data, unsigned datalen)
{
	__u8 digest[SHA_DIGEST_LENGTH];
	const __u8 *rs;
	unsigned rslen;
	unsigned pos = 0;
	DSA_SIG *sig;
	int result;

	SHA1(data, datalen, digest);

		/* !!! SSH_BUG_SIGBLOB */

	if (__unlikely(check_name_from_str(signature, signaturelen, &pos, ssh_dss.name)) ||
	    __unlikely(get_len_str_from_str(signature, signaturelen, &pos, &rs, &rslen)) ||
	    __unlikely(pos != signaturelen) ||
	    __unlikely(rslen != SIGBLOB_LEN))
		return -EINVAL;

	sig = DSA_SIG_new();
	if (__unlikely(!sig))
		return -ENOMEM;

	if (__unlikely(!(sig->r = BN_new())) ||
	    __unlikely(!(sig->s = BN_new()))) {
		result = -ENOMEM;
		goto out;
	}
	if (__unlikely(!BN_bin2bn(rs, INTBLOB_LEN, sig->r)) ||
	    __unlikely(!BN_bin2bn(rs + INTBLOB_LEN, INTBLOB_LEN, sig->s))) {
		result = -EINVAL;
		goto out;
	}

	switch (DSA_do_verify(digest, sizeof digest, sig, &key->dsa)) {
		case 1:		/* signature valid */
			result = 0;
			break;
		case 0:		/* signature does not match */
			result = 1;
			break;
		default:	/* internal error */
			result = -EINVAL;
			break;
	}

	out:
	DSA_SIG_free(sig);
	return result;
}

static void dss_destroy(union p_key *key)
{
	DSA_free(&key->dsa);
}

/* RSA */

/* Load a PEM RSA private key; see key_load_private_from_file() for error reporting. */
static union p_key *rsa_load_private_from_file(const char *filename, char *err)
{
	union p_key *key = key_load_private_from_file(filename, EVP_PKEY_RSA, err);

	return key;
}

static union p_key *rsa_load_public_from_blob(const __u8 *blob, unsigned len)
{
	unsigned sp;

	union p_key *key = (union p_key *)RSA_new();
	if (__unlikely(!key)) {
		errno = ENOMEM;
		return NULL;
	}
	if (__unlikely(!(key->rsa.n = BN_new())) ||
	    __unlikely(!(key->rsa.e = BN_new()))) {
		errno = ENOMEM;
		goto free_ret;
	}
	sp = 0;
	if (__unlikely(check_name_from_str(blob, len, &sp, "ssh-rsa")) ||
	    __unlikely(get_bn_from_str(blob, len, &sp, key->rsa.e)) ||
	    __unlikely(get_bn_from_str(blob, len, &sp, key->rsa.n))) {
		errno = EINVAL;
		goto free_ret;
	}

	return key;

	free_ret:
	rsa_destroy(key);
	return NULL;
}

/*
 * Serialise an RSA public key as an "ssh-rsa" blob into *BLOB / *LEN
 * (initialised with init_str()): the name string, then e, then n.
 * Returns 0 on success or -1 as soon as any append fails.
 */
static int rsa_put_public_to_blob(union p_key *key, __u8 **blob, unsigned *len)
{
	init_str(blob, len);
	if (__unlikely(add_len_str_to_str(blob, len, (__u8 *)ssh_rsa.name, strlen(ssh_rsa.name))) ||
	    __unlikely(add_bn_to_str(blob, len, key->rsa.e)) ||
	    __unlikely(add_bn_to_str(blob, len, key->rsa.n)))
		return -1;
	return 0;
}

/*
 * Produce an "ssh-rsa" signature over DATA.
 *
 * DATA is hashed with SHA-1 and signed with RSA_sign(NID_sha1, ...).  A
 * signature shorter than RSA_size() is left-padded with zeros so that it is
 * always exactly RSA_size() bytes.  The name string and the signature are
 * appended to *SIGNATURE / *SIGNATURELEN, which are (re)initialised with
 * init_str().
 *
 * Returns 0 on success, -EINVAL if RSA_sign() fails, and -ENOMEM on
 * allocation or append failure.
 */
static int rsa_sign(union p_key *key, __u8 **signature, unsigned *signaturelen, const __u8 *data, unsigned datalen)
{
	__u8 digest[SHA_DIGEST_LENGTH];
	__u8 *buf;
	unsigned modlen, siglen;
	int ret;

	/* SSH_BUG_RSASIGMD5 */
	SHA1(data, datalen, digest);

	modlen = RSA_size(&key->rsa);
	buf = malloc(modlen);
	if (__unlikely(!buf))
		return -ENOMEM;

	/* SSH_BUG_RSASIGMD5 */
	if (RSA_sign(NID_sha1, digest, sizeof digest, buf, &siglen, &key->rsa) != 1) {
		ret = -EINVAL;
		goto out;
	}

	if (__unlikely(siglen != modlen)) {
		/* More than RSA_size() bytes means buf has already been overrun. */
		if (__unlikely(siglen > modlen))
			KERNEL$SUICIDE("rsa_sign: RSA_sign shot out of memory, %u > %u", siglen, modlen);
		memmove(buf + (modlen - siglen), buf, siglen);
		memset(buf, 0, modlen - siglen);
	}

	init_str(signature, signaturelen);
	if (__unlikely(add_len_str_to_str(signature, signaturelen, (__u8 *)ssh_rsa.name, strlen(ssh_rsa.name))) ||
	    __unlikely(add_len_str_to_str(signature, signaturelen, buf, modlen))) {
		ret = -ENOMEM;
		goto out;
	}
	ret = 0;

	out:
	free(buf);
	return ret;
}

/*
 * Verify an "ssh-rsa" signature over DATA with the RSA public key KEY.
 *
 * Keys whose modulus is smaller than SSH_RSA_MINIMUM_MODULUS_SIZE bits are
 * refused outright.  SIGNATURE must be exactly the name string followed by
 * one signature string.  A signature longer than RSA_size() is treated as
 * a mismatch; a shorter one is left-padded with zeros before checking.
 *
 * Returns 0 if the signature is valid, 1 if it does not match, -EINVAL for
 * a too-small key, malformed input or an RSA_verify() error, and -ENOMEM on
 * allocation failure.
 */
static int rsa_verify(union p_key *key, const __u8 *signature, unsigned signaturelen, const __u8 *data, unsigned datalen)
{
	__u8 digest[SHA_DIGEST_LENGTH];
	const __u8 *sigbuf;
	unsigned siglen, modlen;
	unsigned pos = 0;
	__u8 *padded = NULL;
	int result;

	if (__unlikely(BN_num_bits(key->rsa.n) < SSH_RSA_MINIMUM_MODULUS_SIZE))
		return -EINVAL;

	/* SSH_BUG_RSASIGMD5 */
	SHA1(data, datalen, digest);

	if (__unlikely(check_name_from_str(signature, signaturelen, &pos, ssh_rsa.name)) ||
	    __unlikely(get_len_str_from_str(signature, signaturelen, &pos, &sigbuf, &siglen)) ||
	    __unlikely(pos != signaturelen))
		return -EINVAL;

	modlen = RSA_size(&key->rsa);
	if (__unlikely(siglen > modlen))
		return 1;
	if (__unlikely(siglen < modlen)) {
		padded = malloc(modlen);
		if (__unlikely(!padded))
			return -ENOMEM;
		memset(padded, 0, modlen - siglen);
		memcpy(padded + (modlen - siglen), sigbuf, siglen);
		sigbuf = padded;
		siglen = modlen;
	}

	/* SSH_BUG_RSASIGMD5 */
	switch (RSA_verify(NID_sha1, digest, sizeof digest, (__u8 *)sigbuf, siglen, &key->rsa)) {
		case 1:		/* signature valid */
			result = 0;
			break;
		case 0:		/* signature does not match */
			result = 1;
			break;
		default:	/* internal error */
			result = -EINVAL;
			break;
	}

	free(padded);
	return result;
}

static void rsa_destroy(union p_key *key)
{
	RSA_free(&key->rsa);
}
