#include <STDLIB.H>
#include <SPAD/LIBC.H>
#include <SPAD/SYNC.H>
#include <OPENSSL/BN.H>
#include <OPENSSL/DH.H>
#include <OPENSSL/SHA.H>

#include "STRUCT.H"
#include "SSHD.H"
#include "OPS.H"

static void kexinit_received(struct ssh_connection *c);
static void drop_packet_fn(struct ssh_connection *c);
static void diffie_hellman_group1_sha1_fn(struct ssh_connection *c);
static int kex_dh_hash(__u8 *result, char *client_version_string, char *server_version_string, __u8 *client_kexinit_msg, unsigned client_kexinit_msg_len, __u8 *server_kexinit_msg, unsigned server_kexinit_msg_len, __u8 *server_host_key_blob, unsigned server_host_key_blob_len, BIGNUM *dh_client_pub, BIGNUM *dh_server_pub, BIGNUM *shared_secret);
static __u8 *derive_key(struct ssh_connection *c, BIGNUM *shared_secret, __u8 *hash, __u8 id, unsigned need);
static void kexdh_wait_for_newkeys(struct ssh_connection *c);

static struct kex_type diffie_hellman_group1_sha1 = {
	"diffie-hellman-group1-sha1", diffie_hellman_group1_sha1_fn,
};

struct kex_type *kex_types[] = {
	&diffie_hellman_group1_sha1,
	NULL
};

/*
 * Build and send the server's SSH_MSG_KEXINIT.
 *
 * Payload layout, in order:
 *  - message byte;
 *  - cookie of four random 32-bit words;
 *  - name-lists: kex, host key, cipher x2, mac x2, compression x2
 *    (each per-direction list is emitted twice with identical contents);
 *  - two empty language lists;
 *  - a zero first_kex_packet_follows byte;
 *  - a zero reserved word.
 *
 * A copy of the finished payload (c->out_stream) is stored in
 * c->server_kexinit_msg / c->server_kexinit_msg_len for the exchange hash.
 * Once the packet is sent, the input upcall becomes kexinit_received.
 *
 * Any helper failure makes the function return early. The name-list buffer
 * is released explicitly only when ssh_add_string fails; the add_*_to_str
 * helpers presumably clean up on their own failure (TODO confirm).
 */
void sshd_kexinit(struct ssh_connection *c)
{
	char *list;
	unsigned list_len;
	int n, copy;

	if (__unlikely(ssh_add_byte(c, SSH_MSG_KEXINIT))) return;
	/* cookie: four random 32-bit words */
	for (n = 0; n < 4; n++)
		if (__unlikely(ssh_add_int32(c, arc4random()))) return;

	/*
	 * Join the ->name fields of a NULL-terminated algorithm table with ','
	 * and append the resulting string `copies` times.
	 */
#define ADD_NAME_LIST(table, copies)						\
	do {									\
		init_str(&list, &list_len);					\
		for (n = 0; (table)[n]; n++) {					\
			if (n && __unlikely(add_chr_to_str(c, &list, &list_len, ','))) return;	\
			if (__unlikely(add_to_str(c, &list, &list_len, (table)[n]->name))) return;	\
		}								\
		for (copy = 0; copy < (copies); copy++) {			\
			if (__unlikely(ssh_add_string(c, list, list_len))) {	\
				__slow_free(list);				\
				return;						\
			}							\
		}								\
		free(list);							\
	} while (0)

	ADD_NAME_LIST(kex_types, 1);
	ADD_NAME_LIST(public_types, 1);
	ADD_NAME_LIST(cipher_types, 2);		/* ctos, stoc */
	ADD_NAME_LIST(mac_types, 2);		/* ctos, stoc */
	ADD_NAME_LIST(compress_types, 2);	/* ctos, stoc */
#undef ADD_NAME_LIST

	/* empty language lists (ctos, stoc) */
	for (copy = 0; copy < 2; copy++)
		if (__unlikely(ssh_add_string(c, NULL, 0))) return;
	/* first_kex_packet_follows, then the reserved word */
	if (__unlikely(ssh_add_byte(c, 0))) return;
	if (__unlikely(ssh_add_int32(c, 0))) return;

	/* reallocf frees the old buffer itself when it fails */
	c->server_kexinit_msg = reallocf(c->server_kexinit_msg, c->out_stream_len);
	if (__unlikely(!c->server_kexinit_msg)) {
		debug_fatal(c, "Can't alloc server kexinit msg");
		abort_ssh_connection(c);
		return;
	}
	memcpy(c->server_kexinit_msg, c->out_stream, c->out_stream_len);
	c->server_kexinit_msg_len = c->out_stream_len;

	if (__unlikely(ssh_send_packet(c, 0))) return;
	c->in_upcall = kexinit_received;
	ssh_get_next_packet(c);
}

/*
 * Input upcall: receive the client's SSH_MSG_KEXINIT and negotiate algorithms.
 *
 * Non-KEXINIT packets are passed to ssh_unknown_packet(). When it returns
 * zero, the next packet is examined; otherwise this function returns.
 *
 * The parsed KEXINIT payload (sp bytes starting at
 * c->in_stream + c->in_stream_start) is copied into c->client_kexinit_msg,
 * because it is an input to the exchange hash.
 *
 * In each algorithm category, the first name in the client's list that also
 * appears in the server's table is chosen. If nothing matches, the connection
 * is aborted with a message showing what the client offered.
 *
 * The chosen algorithms are stored in c->kex_type .. c->compress_stoc. The
 * negotiated kex_fn (or drop_packet_fn, see below) is then installed as the
 * input upcall and called immediately.
 */
static void kexinit_received(struct ssh_connection *c)
{
	unsigned sp;
	__u8 cmd;
	char *party_kex_types;
	unsigned party_kex_types_len;
	char *party_public_types;
	unsigned party_public_types_len;
	char *party_cipher_ctos_types;
	unsigned party_cipher_ctos_types_len;
	char *party_cipher_stoc_types;
	unsigned party_cipher_stoc_types_len;
	char *party_mac_ctos_types;
	unsigned party_mac_ctos_types_len;
	char *party_mac_stoc_types;
	unsigned party_mac_stoc_types_len;
	char *party_compress_ctos_types;
	unsigned party_compress_ctos_types_len;
	char *party_compress_stoc_types;
	unsigned party_compress_stoc_types_len;
	__u8 kex_follows;
	char str[__MAX_STR_LEN];	/* one name from a client list, NUL-terminated */
	unsigned p, strp;
	int drop_packet;
	struct kex_type *kex_type;
	struct public_type *public_type;
	struct cipher_type *cipher_ctos;
	struct cipher_type *cipher_stoc;
	struct mac_type *mac_ctos;
	struct mac_type *mac_stoc;
	struct compress_type *compress_ctos;
	struct compress_type *compress_stoc;
	again:
	sp = 0;
	if (__unlikely(ssh_get_byte(c, &sp, &cmd))) return;
	if (__unlikely(cmd != SSH_MSG_KEXINIT)) {
		if (ssh_unknown_packet(c, 0)) return;
		goto again;
	}
	/* Cookie: four 32-bit words, written to &KERNEL$LIST_END, which appears
	   to serve as a write-only discard sink (NOTE(review): confirm). */
	if (__unlikely(ssh_get_int32(c, &sp, (__u32 *)(void *)&KERNEL$LIST_END))) return;
	if (__unlikely(ssh_get_int32(c, &sp, (__u32 *)(void *)&KERNEL$LIST_END))) return;
	if (__unlikely(ssh_get_int32(c, &sp, (__u32 *)(void *)&KERNEL$LIST_END))) return;
	if (__unlikely(ssh_get_int32(c, &sp, (__u32 *)(void *)&KERNEL$LIST_END))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_kex_types, &party_kex_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_public_types, &party_public_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_cipher_ctos_types, &party_cipher_ctos_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_cipher_stoc_types, &party_cipher_stoc_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_mac_ctos_types, &party_mac_ctos_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_mac_stoc_types, &party_mac_stoc_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_compress_ctos_types, &party_compress_ctos_types_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &party_compress_stoc_types, &party_compress_stoc_types_len))) return;
	/* Language lists (ctos, stoc): parsed and discarded. */
	if (__unlikely(ssh_get_string(c, &sp, (char **)(void *)&KERNEL$LIST_END, (unsigned *)(void *)&KERNEL$LIST_END))) return;
	if (__unlikely(ssh_get_string(c, &sp, (char **)(void *)&KERNEL$LIST_END, (unsigned *)(void *)&KERNEL$LIST_END))) return;
	/* first_kex_packet_follows flag */
	if (__unlikely(ssh_get_byte(c, &sp, &kex_follows))) return;
	/* Reserved word: discarded. */
	if (__unlikely(ssh_get_int32(c, &sp, (__u32 *)(void *)&KERNEL$LIST_END))) return;
	/* Keep a copy of the payload for the exchange hash (reallocf frees on failure). */
	if (__unlikely(!(c->client_kexinit_msg = reallocf(c->client_kexinit_msg, sp)))) {
		debug_fatal(c, "Can't alloc client kexinit msg");
		abort_ssh_connection(c);
		return;
	}
	memcpy(c->client_kexinit_msg, c->in_stream + c->in_stream_start, sp);
	c->client_kexinit_msg_len = sp;
	ssh_got_data(c, &sp);

	/*
	 * first_kex_packet_follows handling. With kex_follows set, drop_packet
	 * becomes 1 whenever a client name is compared with a server name that
	 * does not match before the winner is found. In effect it is set unless
	 * the client's first kex (and host key) name equals the server's first
	 * entry. In that case the client's guessed kex packet is discarded by
	 * drop_packet_fn.
	 */
	drop_packet = 0;

	/*
	 * iterate(var) splits the length-delimited list `var` (var##_len bytes)
	 * at commas. Each name is copied into str and NUL-terminated, and the
	 * statement following iterate(...) runs with it. The end of the list
	 * acts as a final separator.
	 *
	 * Iteration stops early, yielding no further names, if a name reaches
	 * __MAX_STR_LEN characters or the list contains a NUL byte. When the loop
	 * finishes without a `goto`, control falls through to the "Can't agree"
	 * error.
	 */
#define iterate(var)	\
	for (p = 0, strp = 0; __likely(strp < __MAX_STR_LEN) && ((__likely(p < var##_len) && __likely(str[strp++] = var[p])) || (p == var##_len && (str[strp++] = ','))); p++) if (__unlikely(str[strp - 1] == ',') && (str[strp - 1] = 0, strp = 0, 1))
	
	iterate(party_kex_types) {
		int i;
		for (i = 0; kex_types[i]; i++) {
			if (__likely(!strcmp(kex_types[i]->name, str))) {
				kex_type = kex_types[i];
				goto kex_done;
			}
			if (kex_follows) drop_packet = 1;
		}
	}
	debug_fatal(c, "Can't agree on kex algorithm (offered %.*s)", party_kex_types_len, party_kex_types);
	abort_ssh_connection(c);
	return;
	kex_done:
	iterate(party_public_types) {
		int i;
		for (i = 0; public_types[i]; i++) {
			if (__likely(!strcmp(public_types[i]->name, str))) {
				public_type = public_types[i];
				goto public_done;
			}
			if (kex_follows) drop_packet = 1;
		}
	}
	debug_fatal(c, "Can't agree on public algorithm (offered %.*s)", party_public_types_len, party_public_types);
	abort_ssh_connection(c);
	return;
	public_done:
	/* The remaining categories do not affect drop_packet. */
	iterate(party_cipher_ctos_types) {
		int i;
		for (i = 0; cipher_types[i]; i++) {
			if (!strcmp(cipher_types[i]->name, str)) {
				cipher_ctos = cipher_types[i];
				goto cipher_ctos_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on cipher client-to-server algorithm (offered %.*s)", party_cipher_ctos_types_len, party_cipher_ctos_types);
	abort_ssh_connection(c);
	return;
	cipher_ctos_done:
	iterate(party_cipher_stoc_types) {
		int i;
		for (i = 0; cipher_types[i]; i++) {
			if (!strcmp(cipher_types[i]->name, str)) {
				cipher_stoc = cipher_types[i];
				goto cipher_stoc_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on cipher server-to-client algorithm (offered %.*s)", party_cipher_stoc_types_len, party_cipher_stoc_types);
	abort_ssh_connection(c);
	return;
	cipher_stoc_done:
	iterate(party_mac_ctos_types) {
		int i;
		for (i = 0; mac_types[i]; i++) {
			if (!strcmp(mac_types[i]->name, str)) {
				mac_ctos = mac_types[i];
				goto mac_ctos_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on mac client-to-server algorithm (offered %.*s)", party_mac_ctos_types_len, party_mac_ctos_types);
	abort_ssh_connection(c);
	return;
	mac_ctos_done:
	iterate(party_mac_stoc_types) {
		int i;
		for (i = 0; mac_types[i]; i++) {
			if (!strcmp(mac_types[i]->name, str)) {
				mac_stoc = mac_types[i];
				goto mac_stoc_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on mac server-to-client algorithm (offered %.*s)", party_mac_stoc_types_len, party_mac_stoc_types);
	abort_ssh_connection(c);
	return;
	mac_stoc_done:
	iterate(party_compress_ctos_types) {
		int i;
		for (i = 0; compress_types[i]; i++) {
			if (!strcmp(compress_types[i]->name, str)) {
				compress_ctos = compress_types[i];
				goto compress_ctos_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on compress client-to-server algorithm (offered %.*s)", party_compress_ctos_types_len, party_compress_ctos_types);
	abort_ssh_connection(c);
	return;
	compress_ctos_done:
	iterate(party_compress_stoc_types) {
		int i;
		for (i = 0; compress_types[i]; i++) {
			if (!strcmp(compress_types[i]->name, str)) {
				compress_stoc = compress_types[i];
				goto compress_stoc_done;
			}
		}
	}
	debug_fatal(c, "Can't agree on compress server-to-client algorithm (offered %.*s)", party_compress_stoc_types_len, party_compress_stoc_types);
	abort_ssh_connection(c);
	return;
	compress_stoc_done:
#undef iterate
	/*__debug_printf("negotiated: kex %s, public %s, cipher_ctos %s, cipher_stoc %s, mac_ctos %s, mac_stoc %s, compress_ctos %s, compress_stoc %s\n", kex_type->name, public_type->name, cipher_ctos->name, cipher_stoc->name, mac_ctos->name, mac_stoc->name, compress_ctos->name, compress_stoc->name);*/
	c->kex_type = kex_type;
	c->public_type = public_type;
	c->cipher_ctos = cipher_ctos;
	c->cipher_stoc = cipher_stoc;
	c->mac_ctos = mac_ctos;
	c->mac_stoc = mac_stoc;
	c->compress_ctos = compress_ctos;
	c->compress_stoc = compress_stoc;
	/* Discard the client's wrongly-guessed kex packet first, if needed. */
	if (drop_packet) c->in_upcall = drop_packet_fn;
	else c->in_upcall = kex_type->kex_fn;
	c->in_upcall(c);
}

/*
 * Input upcall that discards one packet: the client's guessed kex packet,
 * sent with first_kex_packet_follows, when kexinit_received found the guess
 * wrong. Once ssh_get_byte can read the message byte, the packet is flushed
 * and the negotiated kex handler (c->kex_type->kex_fn) becomes the input
 * upcall and is called. If ssh_get_byte fails, nothing else happens.
 */
static void drop_packet_fn(struct ssh_connection *c)
{
	unsigned pos = 0;
	__u8 msg;

	if (__likely(!ssh_get_byte(c, &pos, &msg))) {
		ssh_flush_stream(c);
		c->in_upcall = c->kex_type->kex_fn;
		c->kex_type->kex_fn(c);
	}
}

static int dh_set_group1(DH *dh);
static int dh_pub_is_valid(DH *dh, BIGNUM *bn);

/*
 * Input upcall: server side of "diffie-hellman-group1-sha1".
 *
 * Steps:
 *  1. Wait for SSH_MSG_KEXDH_INIT carrying the client's public value e.
 *  2. Generate a server key pair on group 1 until its public value f passes
 *     dh_pub_is_valid, then validate e.
 *  3. Compute the shared secret K and the exchange hash H.
 *  4. On the first exchange, record H as c->session_id.
 *  5. Reply with SSH_MSG_KEXDH_REPLY (host key blob, f, signature over H),
 *     followed by SSH_MSG_NEWKEYS.
 *  6. Derive the IVs, keys and MAC keys (letters 'A'..'F'), install the
 *     outgoing cipher and MAC, and wait for the client's NEWKEYS in
 *     kexdh_wait_for_newkeys.
 *
 * Every fatal error is logged with debug_fatal() and aborts the connection.
 * Failures of the ssh_get_* / ssh_add_* / ssh_send_packet helpers only
 * return; presumably those helpers handle the connection state themselves
 * (TODO confirm).
 *
 * Secrets: the raw K buffer is wiped before it is freed, and the K BIGNUM is
 * released with BN_clear_free.
 */
static void diffie_hellman_group1_sha1_fn(struct ssh_connection *c)
{
	unsigned sp;
	__u8 cmd;
	BIGNUM dh_client_pub, shared_secret;
	unsigned klen, server_host_key_blob_len, signature_len;
	int kout;	/* signed: DH_compute_key() returns -1 on error */
	__u8 *kbuf, *server_host_key_blob = NULL, *signature = NULL;
	__u8 hash[SHA_DIGEST_LENGTH];
	BN_init(&dh_client_pub);
	BN_init(&shared_secret);
	again:
	sp = 0;
	if (__unlikely(ssh_get_byte(c, &sp, &cmd))) goto ret;
	if (__unlikely(cmd != SSH_MSG_KEXDH_INIT)) {
		if (ssh_unknown_packet(c, 0)) goto ret;
		goto again;
	}
	/* e: the client's public value */
	if (__unlikely(ssh_get_bignum(c, &sp, &dh_client_pub))) goto ret;
	ssh_got_data(c, &sp);
	dh_again:
	/* (Re)create the server key pair until its public value is acceptable. */
	if (__unlikely(c->dh != NULL)) DH_free(c->dh);
	if (__unlikely(!(c->dh = DH_new()))) {
		debug_fatal(c, "Can't alloc DH");
		abort_ssh_connection(c);
		goto ret;
	}
	if (__unlikely(dh_set_group1(c->dh))) {
		debug_fatal(c, "Can't set DH group1");
		abort_ssh_connection(c);
		goto ret;
	}
	if (__unlikely(!DH_generate_key(c->dh))) {
		debug_fatal(c, "Can't generate DH key");
		abort_ssh_connection(c);
		goto ret;
	}
	if (__unlikely(!dh_pub_is_valid(c->dh, c->dh->pub_key))) goto dh_again;
	if (__unlikely(!dh_pub_is_valid(c->dh, &dh_client_pub))) {
		debug_fatal(c, "Invalid bignum received");
		abort_ssh_connection(c);
		goto ret;
	}

	/* Shared secret K. kbuf holds secret material, so wipe it before freeing. */
	klen = DH_size(c->dh);
	kbuf = malloc(klen);
	if (__unlikely(!kbuf)) {
		debug_fatal(c, "Can't alloc kbuf");
		abort_ssh_connection(c);
		goto ret;
	}
	kout = DH_compute_key(kbuf, &dh_client_pub, c->dh);
	if (__likely(kout >= 0) && __unlikely(!BN_bin2bn(kbuf, kout, &shared_secret)))
		kout = -1;
	memset(kbuf, 0, klen);
	free(kbuf);
	if (__unlikely(kout < 0)) {
		debug_fatal(c, "Can't compute DH shared secret");
		abort_ssh_connection(c);
		goto ret;
	}

	if (__unlikely(c->public_type->get_blob(&server_host_key_blob, &server_host_key_blob_len))) {
		debug_fatal(c, "Can't get blob for key %s", c->public_type->name);
		abort_ssh_connection(c);
		goto ret;
	}

	if (__unlikely(kex_dh_hash(hash, c->client_version_string, idstring, c->client_kexinit_msg, c->client_kexinit_msg_len, c->server_kexinit_msg, c->server_kexinit_msg_len, server_host_key_blob, server_host_key_blob_len, &dh_client_pub, c->dh->pub_key, &shared_secret))) {
		debug_fatal(c, "Can't hash");
		abort_ssh_connection(c);
		goto ret;
	}

	/* The hash of the first key exchange becomes the session identifier. */
	if (__likely(!c->session_id)) {
		if (__unlikely(!(c->session_id = malloc(sizeof hash)))) {
			debug_fatal(c, "Can't alloc session id");
			abort_ssh_connection(c);
			goto ret;
		}
		memcpy(c->session_id, hash, sizeof hash);
		c->session_id_len = sizeof hash;
	}

	if (__unlikely(c->public_type->sign(&signature, &signature_len, hash, sizeof hash))) {
		debug_fatal(c, "Can't sign");
		abort_ssh_connection(c);
		goto ret;
	}

	if (__unlikely(ssh_add_byte(c, SSH_MSG_KEXDH_REPLY))) goto ret;
	if (__unlikely(ssh_add_string(c, server_host_key_blob, server_host_key_blob_len))) goto ret;
	if (__unlikely(ssh_add_bignum(c, c->dh->pub_key))) goto ret;
	if (__unlikely(ssh_add_string(c, signature, signature_len))) goto ret;
	if (__unlikely(ssh_send_packet(c, 1))) goto ret;

	if (__unlikely(ssh_add_byte(c, SSH_MSG_NEWKEYS))) goto ret;
	if (__unlikely(ssh_send_packet(c, 0))) goto ret;

	/* Derived material: 'A'/'B' IVs, 'C'/'D' cipher keys, 'E'/'F' MAC keys. */
	free(c->iv_ctos);
	c->iv_ctos = derive_key(c, &shared_secret, hash, 'A', c->cipher_ctos->blocksize);
	if (__unlikely(!c->iv_ctos)) {
		debug_fatal(c, "Can't derive iv_ctos");
		abort_ssh_connection(c);
		goto ret;
	}
	free(c->iv_stoc);
	c->iv_stoc = derive_key(c, &shared_secret, hash, 'B', c->cipher_stoc->blocksize);
	if (__unlikely(!c->iv_stoc)) {
		debug_fatal(c, "Can't derive iv_stoc");
		abort_ssh_connection(c);
		goto ret;
	}
	free(c->key_ctos);
	c->key_ctos = derive_key(c, &shared_secret, hash, 'C', c->cipher_ctos->keysize);
	if (__unlikely(!c->key_ctos)) {
		debug_fatal(c, "Can't derive key_ctos");
		abort_ssh_connection(c);
		goto ret;
	}
	free(c->key_stoc);
	c->key_stoc = derive_key(c, &shared_secret, hash, 'D', c->cipher_stoc->keysize);
	if (__unlikely(!c->key_stoc)) {
		debug_fatal(c, "Can't derive key_stoc");
		abort_ssh_connection(c);
		goto ret;
	}
	free(c->mac_key_ctos);
	c->mac_key_ctos = derive_key(c, &shared_secret, hash, 'E', c->mac_ctos->keysize);
	if (__unlikely(!c->mac_key_ctos)) {
		debug_fatal(c, "Can't derive mac_key_ctos");
		abort_ssh_connection(c);
		goto ret;
	}
	free(c->mac_key_stoc);
	/* MAC key length comes from the MAC algorithm, not the cipher. */
	c->mac_key_stoc = derive_key(c, &shared_secret, hash, 'F', c->mac_stoc->keysize);
	if (__unlikely(!c->mac_key_stoc)) {
		debug_fatal(c, "Can't derive mac_key_stoc");
		abort_ssh_connection(c);
		goto ret;
	}

	if (__unlikely(setup_out_cipher(c, c->cipher_stoc, c->key_stoc, c->iv_stoc))) {
		debug_fatal(c, "Can't setup out cipher");
		abort_ssh_connection(c);
		goto ret;
	}
	if (__unlikely(setup_out_mac(c, c->mac_stoc, c->mac_key_stoc))) {
		debug_fatal(c, "Can't setup out mac");
		abort_ssh_connection(c);
		goto ret;
	}

	/* todo: setup out compress context */

	c->in_upcall = kexdh_wait_for_newkeys;
	ssh_get_next_packet(c);

	ret:
	BN_free(&dh_client_pub);
	BN_clear_free(&shared_secret);
	free(server_host_key_blob);
	free(signature);
}

/*
 * Load the diffie-hellman-group1-sha1 group parameters into dh->p and dh->g.
 * The prime is the 1024-bit Oakley Group 2 modulus and the generator is 2.
 * Both are parsed with BN_hex2bn.
 * Returns 0 on success, or -1 if either BN_hex2bn call fails.
 */
static int dh_set_group1(DH *dh)
{
	static __const__ char group1_p[] =
	  "FFFFFFFF" "FFFFFFFF" "C90FDAA2" "2168C234" "C4C6628B" "80DC1CD1"
	  "29024E08" "8A67CC74" "020BBEA6" "3B139B22" "514A0879" "8E3404DD"
	  "EF9519B3" "CD3A431B" "302B0A6D" "F25F1437" "4FE1356D" "6D51C245"
	  "E485B576" "625E7EC6" "F44C42E9" "A637ED6B" "0BFF5CB6" "F406B7ED"
	  "EE386BFB" "5A899FA5" "AE9F2411" "7C4B1FE6" "49286651" "ECE65381"
	  "FFFFFFFF" "FFFFFFFF";
	/* char, not __u8: BN_hex2bn takes a (const) char * hex string */
	static __const__ char group1_g[] = "2";
	if (__unlikely(!BN_hex2bn(&dh->p, group1_p))) return -1;
	if (__unlikely(!BN_hex2bn(&dh->g, group1_g))) return -1;
	return 0;
}

/*
 * Check a Diffie-Hellman public value against the group in dh.
 * Returns 1 if bn is acceptable, 0 otherwise.
 *
 * Rejected values:
 *  - negative values, and values >= p;
 *  - values with at most one bit set (0, 1 and powers of two). With g == 2
 *    their discrete logarithm is trivial;
 *  - p - 1, which has order 2 and would force the shared secret into
 *    {1, p - 1}.
 *
 * Performs no allocation, so a rejection always reflects the value itself.
 * The caller relies on this when it regenerates its own key in a loop.
 */
static int dh_pub_is_valid(DH *dh, BIGNUM *bn)
{
	int i, n, bits_set;
	if (__unlikely(bn->neg)) return 0;
	if (__unlikely(BN_cmp(bn, dh->p) != -1)) return 0;
	n = BN_num_bits(bn);
	bits_set = 0;
		/* I think '<' would not be sufficient, but
		   OpenSSH has <= there ... */
	for (i = 0; i <= n; i++) if (BN_is_bit_set(bn, i)) bits_set++;
	if (__unlikely(bits_set <= 1)) return 0;
	/*
	 * Reject p - 1. For odd p, p - 1 is p with bit 0 cleared: it has the
	 * same bit length, bit 0 clear, and every higher bit equal to p's.
	 */
	if (n == BN_num_bits(dh->p) && BN_is_bit_set(dh->p, 0) && !BN_is_bit_set(bn, 0)) {
		for (i = 1; i < n; i++)
			if (BN_is_bit_set(bn, i) != BN_is_bit_set(dh->p, i)) break;
		if (__unlikely(i == n)) return 0;
	}
	return 1;
}

static int kex_dh_hash(__u8 *result, char *client_version_string, char *server_version_string, __u8 *client_kexinit_msg, unsigned client_kexinit_msg_len, __u8 *server_kexinit_msg, unsigned server_kexinit_msg_len, __u8 *server_host_key_blob, unsigned server_host_key_blob_len, BIGNUM *dh_client_pub, BIGNUM *dh_server_pub, BIGNUM *shared_secret)
{
	SHA_CTX ctx;
	__u32 len;
	__u8 *bnbuf;
	SHA1_Init(&ctx);
	len = __32CPU2BE(strlen(client_version_string));
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, client_version_string, __32BE2CPU(len));
	len = __32CPU2BE(strlen(server_version_string));
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, server_version_string, __32BE2CPU(len));
	len = __32CPU2BE(client_kexinit_msg_len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, client_kexinit_msg, client_kexinit_msg_len);
	len = __32CPU2BE(server_kexinit_msg_len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, server_kexinit_msg, server_kexinit_msg_len);
	len = __32CPU2BE(server_host_key_blob_len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, server_host_key_blob, server_host_key_blob_len);
	if (__unlikely(bn_2_string(dh_client_pub, &bnbuf, &len))) return -1;
	len = __32CPU2BE(len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, bnbuf, __32BE2CPU(len));
	free(bnbuf);
	if (__unlikely(bn_2_string(dh_server_pub, &bnbuf, &len))) return -1;
	len = __32CPU2BE(len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, bnbuf, __32BE2CPU(len));
	free(bnbuf);
	if (__unlikely(bn_2_string(shared_secret, &bnbuf, &len))) return -1;
	len = __32CPU2BE(len);
	SHA1_Update(&ctx, &len, sizeof len);
	SHA1_Update(&ctx, bnbuf, __32BE2CPU(len));
	free(bnbuf);
	SHA1_Final(result, &ctx);
	return 0;
}

/*
 * Derive `need` bytes of key material for letter `id`:
 *   K1 = SHA1(K || H || id || session_id)
 *   Kn = SHA1(K || H || K1 || ... || K(n-1))
 * K is the shared secret, serialized by bn_2_string and hashed
 * length-prefixed. H is the exchange hash (SHA_DIGEST_LENGTH bytes).
 *
 * Returns a malloc'd buffer that the caller must free, or NULL on failure.
 * The buffer holds ROUNDUP(need, SHA_DIGEST_LENGTH) bytes, of which the first
 * `need` are the key. For need == 0 a 1-byte buffer is returned, because
 * malloc(0) may legitimately return NULL, which callers treat as failure.
 */
static __u8 *derive_key(struct ssh_connection *c, BIGNUM *shared_secret, __u8 *hash, __u8 id, unsigned need)
{
	__u8 *digest, *bnbuf;
	__u32 bnlen, bnlen_be;
	unsigned have;
	SHA_CTX ctx;
	if (__unlikely(!need)) return malloc(1);
	if (__unlikely(!(digest = malloc(ROUNDUP(need, SHA_DIGEST_LENGTH)))))
		return NULL;
	if (__unlikely(bn_2_string(shared_secret, &bnbuf, &bnlen))) {
		__slow_free(digest);
		return NULL;
	}
	bnlen_be = __32CPU2BE(bnlen);
	SHA1_Init(&ctx);
		/* !!! SSH_BUG_DERIVEKEY -- presumably marks that OpenSSH's
		   compat variant (K hashed without encoding) is not implemented */
	SHA1_Update(&ctx, &bnlen_be, sizeof bnlen_be);
	SHA1_Update(&ctx, bnbuf, bnlen);
	SHA1_Update(&ctx, hash, SHA_DIGEST_LENGTH);
	SHA1_Update(&ctx, &id, 1);
	SHA1_Update(&ctx, c->session_id, c->session_id_len);
	SHA1_Final(digest, &ctx);
	/* Extend with K || H || (all output so far) until `need` bytes exist. */
	for (have = SHA_DIGEST_LENGTH; need > have; have += SHA_DIGEST_LENGTH) {
		SHA1_Init(&ctx);
			/* !!! SSH_BUG_DERIVEKEY */
		SHA1_Update(&ctx, &bnlen_be, sizeof bnlen_be);
		SHA1_Update(&ctx, bnbuf, bnlen);
		SHA1_Update(&ctx, hash, SHA_DIGEST_LENGTH);
		SHA1_Update(&ctx, digest, have);
		SHA1_Final(digest + have, &ctx);
	}
	/* bnbuf holds the serialized shared secret; wipe it before freeing. */
	memset(bnbuf, 0, bnlen);
	free(bnbuf);
	return digest;
}

/*
 * Input upcall after the server has sent SSH_MSG_NEWKEYS.
 *
 * Waits for the client's SSH_MSG_NEWKEYS; other packets go to
 * ssh_unknown_packet(). It then installs the incoming (client-to-server)
 * cipher and MAC from the key material derived in
 * diffie_hellman_group1_sha1_fn. The connection is aborted if either setup
 * fails.
 *
 * Finally it frees all derived IVs and keys and resumes with
 * c->in_afterkex_upcall, clearing that field.
 */
static void kexdh_wait_for_newkeys(struct ssh_connection *c)
{
	unsigned sp;
	__u8 cmd;
	again:
	sp = 0;
	if (__unlikely(ssh_get_byte(c, &sp, &cmd))) return;
	if (__unlikely(cmd != SSH_MSG_NEWKEYS)) {
		if (ssh_unknown_packet(c, 0)) return;
		goto again;
	}
	ssh_got_data(c, &sp);
	ssh_flush_stream(c);
	/* Incoming traffic is client-to-server: its cipher type must match key_ctos/iv_ctos. */
	if (__unlikely(setup_in_cipher(c, c->cipher_ctos, c->key_ctos, c->iv_ctos))) {
		debug_fatal(c, "Can't setup in cipher");
		abort_ssh_connection(c);
		return;
	}
	if (__unlikely(setup_in_mac(c, c->mac_ctos, c->mac_key_ctos))) {
		debug_fatal(c, "Can't setup in mac");
		abort_ssh_connection(c);
		return;
	}

	/* todo: setup in compress context */

	free(c->iv_ctos); c->iv_ctos = NULL;
	free(c->iv_stoc); c->iv_stoc = NULL;
	free(c->key_ctos); c->key_ctos = NULL;
	free(c->key_stoc); c->key_stoc = NULL;
	free(c->mac_key_ctos); c->mac_key_ctos = NULL;
	free(c->mac_key_stoc); c->mac_key_stoc = NULL;
	c->in_upcall = c->in_afterkex_upcall;
	c->in_afterkex_upcall = NULL;
	ssh_get_next_packet(c);
}
