#include <STDLIB.H>
#include <SIGNAL.H>
#include <SPAD/DEV.H>
#include <SPAD/LIBC.H>
#include <ARCH/SETUP.H>

#include "STRUCT.H"
#include "SSHD.H"
#include "OPS.H"

/*
 * Slide the unread bytes of the raw socket input buffer down to offset 0,
 * discarding everything before in_buffer_start.
 */
static void fold_in_buffer(struct ssh_connection *c)
{
	CHECKH("fold_in_buffer");
	/* After the subtraction in_buffer_end is the number of unread bytes. */
	c->in_buffer_end -= c->in_buffer_start;
	memmove(c->in_buffer, c->in_buffer + c->in_buffer_start, c->in_buffer_end);
	c->in_buffer_start = 0;
}

/*
 * Slide the unread bytes of the decoded input stream down to offset 0,
 * discarding everything before in_stream_start.
 */
static void fold_in_stream(struct ssh_connection *c)
{
	CHECKH("fold_in_stream");
	/* After the subtraction in_stream_end is the number of unread bytes. */
	c->in_stream_end -= c->in_stream_start;
	memmove(c->in_stream, c->in_stream + c->in_stream_start, c->in_stream_end);
	c->in_stream_start = 0;
}

/*
 * Issue a KERNEL$READ request on c->in_sock that appends to in_buffer.
 *
 * If the buffer has no free tail space it is first compacted (when a
 * consumed prefix exists) and otherwise grown by 256 bytes.  Growth is
 * refused once the allocation has reached MAX_PACKET_LENGTH.  On refusal
 * or allocation failure the connection is aborted and no read is issued.
 */
void ssh_read_more(struct ssh_connection *c)
{
	unsigned room;
	CHECKH("ssh_read_more");
	room = __alloc_size(c->in_buffer) - c->in_buffer_end;
	if (__unlikely(!room) && __likely(c->in_buffer_start)) {
		/* Reclaim the consumed prefix before considering a realloc. */
		fold_in_buffer(c);
		room = __alloc_size(c->in_buffer) - c->in_buffer_end;
	}
	if (__unlikely(!room)) {
		if (__unlikely(__alloc_size(c->in_buffer) >= MAX_PACKET_LENGTH)) {
			debug_fatal(c, "Maximum packet length exceeded in in_buffer (current size %ld)", (unsigned long)__alloc_size(c->in_buffer));
			abort_ssh_connection(c);
			return;
		}
		c->in_buffer = reallocf(c->in_buffer, __alloc_size(c->in_buffer) + 256);
		if (__unlikely(!c->in_buffer)) {
			debug_fatal(c, "Can't realloc in_buffer");
			abort_ssh_connection(c);
			return;
		}
		room = __alloc_size(c->in_buffer) - c->in_buffer_end;
	}
	/* Describe the free tail of in_buffer as the read target. */
	c->in_sock.v.ptr = (unsigned long)(c->in_buffer + c->in_buffer_end);
	c->in_sock.v.len = room;
	c->in_sock.progress = 0;
	CHECKH("ssh_read_more2");
	RAISE_SPL(SPL_BOTTOM);
	c->outstanding++;
	CALL_IORQ(&c->in_sock, KERNEL$READ);
	LOWER_SPL(SPL_ZERO);
}

/* Forward declaration; the definition appears later in this file. */
static void ssh_read_id(struct ssh_connection *c);

/*
 * Put the connection on in_sock_queue unless it is already there
 * (tracked by SSHD_ON_QUEUE in xflags), then call KERNEL$INTR_SYSCALL()
 * if do_interrupt is set.
 */
static __finline__ void do_event(struct ssh_connection *c)
{
	if (__unlikely(c->xflags & SSHD_ON_QUEUE)) {
		return;
	}
#if __DEBUG >= 3
	{
		/* Paranoia: the flag says "not queued", so the list must agree. */
		struct ssh_connection *queued;
		LIST_FOR_EACH(queued, &in_sock_queue, struct ssh_connection, in_sock_list) {
			if (__unlikely(queued == c)) KERNEL$SUICIDE("ssh_in_sock: CONNECTION ALREADY ON QUEUE");
		}
	}
#endif
	c->xflags |= SSHD_ON_QUEUE;
	ADD_TO_LIST_END(&in_sock_queue, &c->in_sock_list);
	if (do_interrupt) KERNEL$INTR_SYSCALL();
}

DECL_AST(ssh_in_sock, SPL_BOTTOM, SIORQ)
{
	struct ssh_connection *c = GET_STRUCT(RQ, struct ssh_connection, in_sock);
	CHECKH("ssh_in_sock");
	c->outstanding--;
	if (__unlikely(c->in_sock.status <= 0)) {
		if (c->in_sock.status != -EINTR && c->in_sock.status != 0) debug_fatal(c, "Input error %ld", c->in_sock.status);
		c->xflags |= SSHD_NEED_ABORT;
		goto add_to_q;
	}
	c->in_buffer_end += c->in_sock.status;
	add_to_q:
	c->xflags |= SSHD_CALL_IN_UPCALL;
	do_event(c);
	RETURN;
}

DECL_AST(ssh_out_tty, SPL_BOTTOM, IOCTLRQ)
{
	struct ssh_connection *c = GET_STRUCT(RQ, struct ssh_connection, out_tty);
	CHECKH("ssh_out_tty");
	c->outstanding--;
	if (__unlikely(c->out_tty.status < 0)) {
		if (c->out_tty.status != -EINTR) debug_fatal(c, "TTYSRVR error %ld on request %x", c->out_tty.status, c->out_tty.ioctl);
		c->xflags |= SSHD_NEED_ABORT;
	}
	c->xflags |= SSHD_CALL_OUT_TTY_UPCALL;
	do_event(c);
	RETURN;
}

DECL_AST(ssh_in_tty, SPL_BOTTOM, IOCTLRQ)
{
	struct ssh_connection *c = GET_STRUCT(RQ, struct ssh_connection, in_tty);
	CHECKH("ssh_in_tty");
	c->outstanding--;
	if (__unlikely(c->in_tty.status < 0)) {
		if (c->in_tty.status != -EINTR) debug_fatal(c, "TTYSRVR error %ld on request %x", c->in_tty.status, c->in_tty.ioctl);
		c->xflags |= SSHD_NEED_ABORT;
	}
	c->xflags |= SSHD_CALL_IN_TTY_UPCALL;
	do_event(c);
	RETURN;
}

/*
 * Consume the client's identification line from in_buffer.
 *
 * Complete lines that are shorter than five bytes or do not start with
 * "SSH-" are skipped.  The first "SSH-" line must carry version "2..."
 * or "1.99"; anything else aborts the connection.  The accepted line,
 * without its trailing "\r\n" or "\n", is stored NUL-terminated in
 * client_version_string, and client_prog_id is pointed just past the
 * first '-' found after the "SSH-" prefix (or at the start of the string
 * if there is none).  SSHD_SKIP_ID is then cleared.
 *
 * If no complete line is buffered, or the buffer is empty after the
 * identification line, another read is requested; otherwise in_upcall
 * is invoked to process the remaining bytes.
 */
static void ssh_read_id(struct ssh_connection *c)
{
	__u8 *line;
	unsigned avail, n;
	CHECKH("ssh_read_id");
	for (;;) {
		line = c->in_buffer + c->in_buffer_start;
		avail = c->in_buffer_end - c->in_buffer_start;
		/* n becomes the offset of the first '\n', or avail if none. */
		n = 0;
		while (n < avail && __likely(line[n] != '\n')) n++;
		if (n == avail) {
			ssh_read_more(c);
			return;
		}
		/* The line is consumed whether or not it is the banner. */
		c->in_buffer_start += n + 1;
		if (__likely(n >= 5) && __likely(!memcmp(line, "SSH-", 4))) break;
	}
	if (__unlikely(line[4] != '2')) {
		int is_1_99 = n >= 8 && line[4] == '1' && line[5] == '.' && line[6] == '9' && line[7] == '9';
		if (__unlikely(!is_1_99)) {
			debug_fatal(c, "Incompatible client version %d", line[4] - '0');
			abort_ssh_connection(c);
			return;
		}
	}
	/* Drop an optional carriage return preceding the newline. */
	if (line[n - 1] == '\r') n--;
	c->client_version_string = reallocf(c->client_version_string, n + 1);
	if (__unlikely(!c->client_version_string)) {
		debug_fatal(c, "Can't alloc client version string");
		abort_ssh_connection(c);
		return;
	}
	memcpy(c->client_version_string, line, n);
	c->client_version_string[n] = 0;
	c->client_prog_id = strchr(c->client_version_string + 4, '-');
	if (__likely(c->client_prog_id)) c->client_prog_id++;
	else c->client_prog_id = c->client_version_string;
	c->flags &= ~SSHD_SKIP_ID;
	if (c->in_buffer_start == c->in_buffer_end) {
		ssh_read_more(c);
		return;
	}
	c->in_upcall(c);
}

/*
 * Try to assemble, authenticate and decode one incoming packet.
 *
 * While SSHD_SKIP_ID is set the work is delegated to ssh_read_id().
 * Otherwise the function decrypts in_buffer one cipher block at a time
 * into in_packet until the length announced in the packet header is
 * reached, verifies the MAC that follows the ciphertext, inflates the
 * payload (without padding) onto the end of in_stream and finally calls
 * c->in_upcall.
 *
 * Whenever more socket data is needed, ssh_read_more() is issued and the
 * function returns; a partially decrypted packet is kept in in_packet /
 * in_packet_len so a later call resumes where this one stopped.  Any
 * protocol, allocation, cipher, MAC or inflate failure aborts the
 * connection.
 */
void ssh_get_next_packet(struct ssh_connection *c)
{
	int r;
	__u32 len;
	__u8 *cmp, *res;
	unsigned cmpl, resl;
	__u8 mac[EVP_MAX_MD_SIZE];
	CHECKH("ssh_get_next_packet");
	if (__unlikely(c->flags & SSHD_SKIP_ID)) {
		ssh_read_id(c);
		return;
	}
	/* Resuming a partially decrypted packet: the header is already there. */
	if (__unlikely(c->in_packet_len)) goto check_mac;
	new_data:
	if (__unlikely(c->in_buffer_end - c->in_buffer_start < c->in_cipher.type->blocksize)) {
		ssh_read_more(c);
		return;
	}
	/* Keep the cipher input aligned to unsigned long. */
	if (__unlikely(c->in_buffer_start & (sizeof(unsigned long) - 1))) {
		fold_in_buffer(c);
	}
	if (__unlikely(c->in_packet_len + c->in_cipher.type->blocksize + c->in_mac.type->size > __alloc_size(c->in_packet))) {
		if (__unlikely(__alloc_size(c->in_packet) >= MAX_PACKET_LENGTH)) {
			debug_fatal(c, "Maximum packet length exceeded in in_packet (current size %ld)", (unsigned long)__alloc_size(c->in_packet));
			abort_ssh_connection(c);
			return;
		}
		if (__unlikely(!(c->in_packet = reallocf(c->in_packet, c->in_packet_len + c->in_cipher.type->blocksize + c->in_mac.type->size)))) {
			debug_fatal(c, "Can't realloc in_packet");
			abort_ssh_connection(c);
			return;
		}
	}
	/* Decrypt exactly one block and append it to the plaintext packet. */
	if (__unlikely(!EVP_Cipher(&c->in_cipher.evp_ctx, c->in_packet + c->in_packet_len, c->in_buffer + c->in_buffer_start, c->in_cipher.type->blocksize))) {
		debug_fatal(c, "EVP_Cipher decrypt failed");
		abort_ssh_connection(c);
		return;
	}
	c->in_packet_len += c->in_cipher.type->blocksize;
	c->in_buffer_start += c->in_cipher.type->blocksize;
	check_mac:
#define pkt(c)	((struct ssh_packet *)c->in_packet)
	/* Total packet size = big-endian length field + the 4 bytes of the field itself. */
	len = __32BE2CPU(pkt(c)->length);
	len += 4;
	if (__unlikely(len > MAX_PACKET_LENGTH) || __unlikely(len < 4 + 9)) {
		debug_fatal(c, "Received packet too big (%08X)", (unsigned)len);
		abort_ssh_connection(c);
		return;
	}
	if (__likely(len > c->in_packet_len)) goto new_data;
	if (__unlikely(len < c->in_packet_len)) {
		debug_fatal(c, "Packet length not multiple of block size (block size %d, requested lenght %d, received %d", c->in_cipher.type->blocksize, len, c->in_packet_len);
		abort_ssh_connection(c);
		return;
	}
	if (__unlikely(c->in_buffer_end - c->in_buffer_start < c->in_mac.type->size)) {
		ssh_read_more(c);
		return;
	}
	c->in_mac.type->mac(&c->in_mac, c->in_packet_seq, c->in_packet, c->in_packet_len, mac);
	{
		/*
		 * Compare the computed and received MAC without an early
		 * exit, so the time taken does not depend on how many
		 * leading bytes match.
		 */
		const __u8 *wire_mac = c->in_buffer + c->in_buffer_start;
		unsigned mac_len = c->in_mac.type->size;
		unsigned diff = 0;
		unsigned k;
		for (k = 0; k < mac_len; k++) diff |= (unsigned)(mac[k] ^ wire_mac[k]);
		if (__unlikely(diff)) {
			debug_fatal(c, "MAC check failed");
			abort_ssh_connection(c);
			return;
		}
	}
	c->in_buffer_start += c->in_mac.type->size;
	c->in_packet_seq++;
	/* len - 5 = bytes after the length field and the padding_length byte. */
	if (__unlikely(pkt(c)->padding_length > len - 5)) {
		debug_fatal(c, "Padding (%d) > length (%d)", pkt(c)->padding_length, len - 5);
		abort_ssh_connection(c);
		return;
	}
	cmp = pkt(c)->data;
	cmpl = len - 5 - pkt(c)->padding_length;
	inflate_again:
	/* Make sure in_stream has free tail space: compact first, then grow. */
	if (__unlikely(!(resl = __alloc_size(c->in_stream) - c->in_stream_end))) {
		if (__likely(c->in_stream_start)) {
			fold_in_stream(c);
			goto inflate_again;
		}
		if (__unlikely(__alloc_size(c->in_stream) >= MAX_STREAM_LENGTH)) {
			debug_fatal(c, "Maximum stream length exceeded in in_stream (current size %ld)", (unsigned long)__alloc_size(c->in_stream));
			abort_ssh_connection(c);
			return;
		}
		if (__unlikely(!(c->in_stream = reallocf(c->in_stream, __alloc_size(c->in_stream) + 256)))) {
			debug_fatal(c, "Can't realloc in_stream");
			abort_ssh_connection(c);
			return;
		}
		resl = __alloc_size(c->in_stream) - c->in_stream_end;
	}
	res = c->in_stream + c->in_stream_end;
	if (__unlikely(r = c->in_compress.type->inflate(&c->in_compress, &cmp, &cmpl, &res, &resl))) {
		debug_fatal(c, "Inflate error: %d", r);
		abort_ssh_connection(c);
		return;
	}
	c->in_stream_end = res - c->in_stream;
	/* Input left over: loop to make more output room and continue. */
	if (__unlikely(cmpl)) goto inflate_again;
	c->in_packet_len = 0;
	CHECKH("ssh_get_next_packet_2");
	if (__unlikely(!c->in_upcall)) KERNEL$SUICIDE("ssh_get_next_packet: %s: NO UPCALL", c->tty_name);
	c->in_upcall(c);
#undef pkt
}

/*
 * Turn the pending out_stream payload into one wire packet and queue it.
 *
 * The payload is deflated into a scratch buffer behind a 5-byte header
 * (length + padding_length), padded with random bytes to a multiple of
 * the cipher block size, encrypted onto the end of out_buffer and
 * followed by its MAC.  If SSHD_SEND_ID is still set, idstring plus
 * "\r\n" is placed in out_buffer first.
 *
 * hold: when non-zero and out_buffer_len is still below
 *       MAX_STREAM_LENGTH, the packet is left in out_buffer and no write
 *       is started.  A write is also not started while a previous one is
 *       in progress (out_in_progress != NULL).
 *
 * Returns 0 on success.  On any failure the connection is aborted and
 * -1 is returned.
 */
int ssh_send_packet(struct ssh_connection *c, int hold)
{
	__u8 *compressed, *dec, *com;
	unsigned compressed_len, decl, coml;
	unsigned pktlen, padding;
	unsigned tail_reserve;
	int r;
	CHECKH("ssh_send_packet");
	if (__unlikely(!c->out_stream_len)) {
		debug_internal(c, "Sending zero-sized packet");
		abort_ssh_connection(c);
		return -1;
	}
	/*
	 * Worst-case room needed after the deflated data for the padding:
	 * pktlen below is at most
	 * compressed_len + MIN_PADDING_LENGTH + blocksize - 1.
	 */
	tail_reserve = c->out_cipher.type->blocksize + (MIN_PADDING_LENGTH - 1);
	if (__unlikely(!(compressed = malloc(5 + c->out_stream_len + tail_reserve)))) {
		debug_fatal(c, "Can't alloc compress buffer");
		abort_ssh_connection(c);
		return -1;
	}
	dec = c->out_stream;
	decl = c->out_stream_len;
	com = compressed + 5;
	coml = __alloc_size(compressed) - 5 - tail_reserve;
	again_deflate:
	if (__unlikely(r = c->out_compress.type->deflate(&c->out_compress, &dec, &decl, &com, &coml))) {
		__slow_free(compressed);
		debug_fatal(c, "Deflate error: %d", r);
		abort_ssh_connection(c);
		return -1;
	}
	if (__unlikely(decl)) {
		/* Output space ran out before all input was consumed: grow. */
		__u8 *cm;
		if (__unlikely(!(cm = reallocf(compressed, __alloc_size(compressed) + 256)))) {
			debug_fatal(c, "Can't realloc compress buffer");
			abort_ssh_connection(c);
			return -1;
		}
		com = cm + (com - compressed);
		compressed = cm;
		/*
		 * Keep the padding reserve out of deflate's reach, exactly as
		 * for the initial allocation.  Handing deflate the whole
		 * remainder would let the padding written below run past the
		 * end of the buffer.  com never exceeded the old size minus
		 * tail_reserve, and the buffer was asked to grow by 256, so
		 * this cannot underflow.
		 */
		coml = __alloc_size(compressed) - (com - compressed) - tail_reserve;
		goto again_deflate;
	}
	compressed_len = com - compressed;
	pktlen = ROUNDUP(compressed_len + MIN_PADDING_LENGTH, c->out_cipher.type->blocksize);
#if __DEBUG >= 1
	if (__unlikely(pktlen > __alloc_size(compressed)))
		KERNEL$SUICIDE("ssh_send_packet: shot out of memory: %d > %ld", pktlen, (unsigned long)__alloc_size(compressed));
#endif
	padding = pktlen - compressed_len;
#define pkt	((struct ssh_packet *)compressed)
	/* The length field does not count its own 4 bytes. */
	pkt->length = __32CPU2BE(pktlen - 4);
	pkt->padding_length = padding;
	/* Fill the padding with random bytes, up to four per draw. */
	while (1) {
		__u32 rnd;
		next_rnd:
		rnd = arc4random();
		if (__unlikely(arc4random_error())) {
			if (arc4random_stir()) goto next_rnd;
			__slow_free(compressed);
			debug_fatal(c, "Can't get random number");
			abort_ssh_connection(c);
			return -1;
		}
		if (__likely(padding > 4)) {
			memcpy(compressed + compressed_len, &rnd, 4);
			padding -= 4;
			compressed_len += 4;
			continue;
		}
		memcpy(compressed + compressed_len, &rnd, padding);
		compressed_len += padding;
		break;
	}
#if __DEBUG >= 1
	if (__unlikely(MOD(compressed_len, c->out_cipher.type->blocksize))) {
		__slow_free(compressed);
		debug_internal(c, "Packet length (%d) not multiple of blocksize (%d)", compressed_len, c->out_cipher.type->blocksize);
		abort_ssh_connection(c);
		return -1;
	}
#endif
	/* out_buffer / out_in_progress are also touched by ssh_out_sock. */
	RAISE_SPL(SPL_BOTTOM);
	if (__unlikely(c->flags & SSHD_SEND_ID)) {
		unsigned isl = strlen(idstring) + 2;
		if (__unlikely(!(c->out_buffer = reallocf(c->out_buffer, c->out_buffer_len + isl)))) {
			LOWER_SPL(SPL_ZERO);
			__slow_free(compressed);
			debug_fatal(c, "Can't realloc out_buffer");
			abort_ssh_connection(c);
			return -1;
		}
		memcpy(c->out_buffer + c->out_buffer_len, idstring, isl - 2);
		c->out_buffer[c->out_buffer_len + isl - 2] = '\r';
		c->out_buffer[c->out_buffer_len + isl - 1] = '\n';
		c->out_buffer_len += isl;
		c->flags &= ~SSHD_SEND_ID;
	}
	if (__unlikely(c->out_buffer_len + compressed_len + c->out_mac.type->size > __alloc_size(c->out_buffer))) {
		if (__unlikely(!(c->out_buffer = reallocf(c->out_buffer, c->out_buffer_len + compressed_len + c->out_mac.type->size)))) {
			LOWER_SPL(SPL_ZERO);
			__slow_free(compressed);
			debug_fatal(c, "Can't realloc out_buffer");
			abort_ssh_connection(c);
			return -1;
		}
	}
	if (__unlikely(!EVP_Cipher(&c->out_cipher.evp_ctx, c->out_buffer + c->out_buffer_len, compressed, compressed_len))) {
		LOWER_SPL(SPL_ZERO);
		__slow_free(compressed);
		debug_fatal(c, "EVP_Cipher encrypt failed");
		abort_ssh_connection(c);
		return -1;
	}
	/* The MAC is computed over the plaintext and stored right after the ciphertext. */
	c->out_mac.type->mac(&c->out_mac, c->out_packet_seq, compressed, compressed_len, c->out_buffer + c->out_buffer_len + compressed_len);
	free(compressed);
	c->out_buffer_len += compressed_len + c->out_mac.type->size;
	c->out_packet_seq++;
	c->out_stream_len = 0;
	if (__unlikely(c->out_in_progress != NULL) || (__unlikely(hold) && __likely(c->out_buffer_len < MAX_STREAM_LENGTH))) {
		LOWER_SPL(SPL_ZERO);
		return 0;
	}
	/* Hand the whole out_buffer to the socket write; ssh_out_sock owns it from here. */
	c->out_in_progress = c->out_buffer;
	c->out_sock.v.ptr = (unsigned long)c->out_buffer;
	c->out_sock.v.len = c->out_buffer_len;
	c->out_sock.progress = 0;
	c->outstanding++;
	CALL_IORQ(&c->out_sock, KERNEL$WRITE);
	c->out_buffer = NULL;
	c->out_buffer_len = 0;
	LOWER_SPL(SPL_ZERO);
	return 0;
#undef pkt
}

DECL_AST(ssh_out_sock, SPL_BOTTOM, SIORQ)
{
	struct ssh_connection *c = GET_STRUCT(RQ, struct ssh_connection, out_sock);
	CHECKH("ssh_out_sock");
	c->outstanding--;
	if (__unlikely(c->out_sock.status <= 0)) {
		if (c->out_sock.status != -EINTR && c->out_sock.status != 0) debug_fatal(c, "Output error %ld", c->out_sock.status);
		__slow_free(c->out_in_progress);
		c->out_in_progress = NULL;
		c->xflags |= SSHD_NEED_ABORT;
		do_event(c);
		RETURN;
	}
	if (__unlikely(c->out_sock.v.len != 0)) {
		c->out_sock.progress = 0;
		c->outstanding++;
		RETURN_IORQ(&c->out_sock, KERNEL$WRITE);
	}
	if (__unlikely(c->out_buffer_len != 0)) {
		__u8 *o = c->out_in_progress;
		c->out_in_progress = c->out_buffer;
		c->out_sock.v.ptr = (unsigned long)c->out_buffer;
		c->out_sock.v.len = c->out_buffer_len;
		c->out_sock.progress = 0;
		c->outstanding++;
		CALL_IORQ(&c->out_sock, KERNEL$WRITE);
		c->out_buffer = o;
		c->out_buffer_len = 0;
		if (c->xflags & SSHD_CALL_MASK) {
			do_event(c);
		}
		RETURN;
	}
	if (__likely(!c->out_buffer)) c->out_buffer = c->out_in_progress;
	else free(c->out_in_progress);
	c->out_in_progress = NULL;
	RETURN;
}

/*
 * Handle a message the caller's own dispatcher does not process.
 *
 * DISCONNECT aborts the connection.  IGNORE, DEBUG and UNIMPLEMENTED are
 * parsed and dropped.  USERAUTH_REQUEST is dropped and the stream is
 * flushed.  KEXINIT, when allow_kex is non-zero, saves the current
 * in_upcall in in_afterkex_upcall and calls sshd_kexinit().  Anything
 * else is answered with SSH_MSG_UNIMPLEMENTED carrying the sequence
 * number of the packet just received (in_packet_seq - 1).
 *
 * Returns 0 for the messages that were dropped or answered, and 1 after
 * a disconnect, a key-exchange hand-off or a failing helper call.
 */
int ssh_unknown_packet(struct ssh_connection *c, int allow_kex)
{
	unsigned pos = 0;
	__u8 msg;
	CHECKH("ssh_unknown_packet");
	if (__unlikely(ssh_get_byte(c, &pos, &msg))) return 1;
	switch (msg) {
		case SSH_MSG_DISCONNECT:
			abort_ssh_connection(c);
			return 1;
		case SSH_MSG_IGNORE:
			if (__unlikely(ssh_get_string(c, &pos, (void *)&KERNEL$LIST_END, (void *)&KERNEL$LIST_END))) return 1;
			ssh_got_data(c, &pos);
			return 0;
		case SSH_MSG_DEBUG:
			/* One byte and two strings, all discarded. */
			if (__unlikely(ssh_get_byte(c, &pos, &msg))) return 1;
			if (__unlikely(ssh_get_string(c, &pos, (void *)&KERNEL$LIST_END, (void *)&KERNEL$LIST_END))) return 1;
			if (__unlikely(ssh_get_string(c, &pos, (void *)&KERNEL$LIST_END, (void *)&KERNEL$LIST_END))) return 1;
			ssh_got_data(c, &pos);
			return 0;
		case SSH_MSG_UNIMPLEMENTED:
			if (__unlikely(ssh_get_int32(c, &pos, (void *)&KERNEL$LIST_END))) return 1;
			ssh_got_data(c, &pos);
			return 0;
		case SSH_MSG_USERAUTH_REQUEST:
			ssh_got_data(c, &pos);
			ssh_flush_stream(c);
			return 0;
		case SSH_MSG_KEXINIT:
			if (__likely(allow_kex)) {
				c->in_afterkex_upcall = c->in_upcall;
				sshd_kexinit(c);
				return 1;
			}
			/* Key exchange not allowed here: treat as unknown. */
			break;
		default:
			break;
	}
	debug_warning(c, "Unknown message %d", msg);
	ssh_flush_stream(c);
	if (__unlikely(ssh_add_byte(c, SSH_MSG_UNIMPLEMENTED))) return 1;
	if (__unlikely(ssh_add_int32(c, c->in_packet_seq - 1))) return 1;
	if (__unlikely(ssh_send_packet(c, 0))) return 1;
	return 0;
}

/*
 * Send SSH_MSG_DISCONNECT with the given code and text, followed by the
 * two-byte string "en" (presumably the language tag; confirm against
 * RFC 4253), then abort the connection.
 *
 * If building or sending the packet fails, the function returns without
 * calling abort_ssh_connection() itself.
 */
void ssh_send_disconnect(struct ssh_connection *c, __u32 code, char *string)
{
	CHECKH("ssh_send_disconnect");
	/* Short-circuit: stop at the first helper that reports failure. */
	if (__unlikely(ssh_add_byte(c, SSH_MSG_DISCONNECT))
	 || __unlikely(ssh_add_int32(c, code))
	 || __unlikely(ssh_add_string(c, string, strlen(string)))
	 || __unlikely(ssh_add_string(c, "en", 2))
	 || __unlikely(ssh_send_packet(c, 0))) {
		return;
	}
	abort_ssh_connection(c);
}


