#include <STDLIB.H>
#include <STRING.H>
#include <ENDIAN.H>
#include <ARPA/INET.H>
#include <SPAD/DEV.H>
#include <SPAD/SYNC.H>
#include <RESOLV.H>

#include "SSHD.H"
#include "STRUCT.H"
#include "OPS.H"

/* Number of authentication failures after which ssh_auth_failure() disconnects the client. */
#define MAX_AUTH_TRIES		20

/*
 * Forward declarations: sshd_do_userauth() calls ssh_auth_failure() and
 * ssh_auth_pk_ok(), which in turn re-enter sshd_do_userauth(), and
 * sshd_do_userauth() installs ssh_got_auth_response() as a completion upcall.
 */
static void sshd_do_userauth(struct ssh_connection *c);
static void ssh_got_auth_response(struct ssh_connection *c);
static void ssh_auth_failure(struct ssh_connection *c);
static void ssh_auth_pk_ok(struct ssh_connection *c, const __u8 *algorithm, unsigned algorithm_len, const __u8 *key, unsigned key_len);

/*
 * Service-request stage preceding user authentication (RFC 4253 section 10).
 *
 * Packets other than SSH_MSG_SERVICE_REQUEST are passed to
 * ssh_unknown_packet(). If the requested service is not "ssh-userauth" the
 * client is disconnected with SERVICE_NOT_AVAILABLE. Otherwise the service is
 * acknowledged with SSH_MSG_SERVICE_ACCEPT, c->tty_name is set to
 * ".<peer address>.<peer port>", the failure counter is reset and control is
 * handed to sshd_do_userauth(), which also becomes c->in_upcall.
 * Returns early whenever one of the connection helpers reports failure.
 */
void sshd_userauth(struct ssh_connection *c)
{
	const __u8 *service;
	unsigned service_len;
	unsigned sp;
	__u8 cmd;

	/* Skip anything that is not a service request. */
	for (;;) {
		sp = 0;
		if (__unlikely(ssh_get_chr_from_connection(c, &sp, &cmd))) return;
		if (__likely(cmd == SSH_MSG_SERVICE_REQUEST)) break;
		if (ssh_unknown_packet(c, 1)) return;
	}
	if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &service, &service_len))) return;
	ssh_got_data(c, &sp);

	if (__unlikely(service_len != 12) || __unlikely(memcmp(service, "ssh-userauth", 12) != 0)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_SERVICE_NOT_AVAILABLE, "SERVICE NOT SUPPORTED");
		return;
	}

	/* SSH_MSG_SERVICE_ACCEPT echoes the accepted service name back. */
	if (__unlikely(ssh_add_chr_to_connection(c, SSH_MSG_SERVICE_ACCEPT))
	 || __unlikely(ssh_add_len_str_to_connection(c, service, service_len))
	 || __unlikely(ssh_send_packet(c, 0)))
		return;

	/* Name of the tty requested from TTYSRVR, derived from the peer endpoint. */
	_snprintf(c->tty_name, TTYSTR_LEN, ".%s.%d", inet_ntoa(c->sin.sin_addr), ntohs(c->sin.sin_port));

	c->auth_count = 0;
	c->in_upcall = sshd_do_userauth;
	sshd_do_userauth(c);
}

/* Authentication method chosen in sshd_do_userauth(); selects which credential line is sent to TTYSRVR. */
#define AUTH_NONE	0
#define AUTH_PASSWORD	1
#define AUTH_PUBLIC_KEY	2

/* Bytes available for the auth text that follows the TTYSTR_LEN-byte tty name in the TTYSRVR message. */
#define AUTH_STR_LEN	8192
	/* same as TTY_BUFFER_SIZE in TTYSRVR.C */

/* Method list advertised to the client in SSH_MSG_USERAUTH_FAILURE. */
#define AUTH_METHODS	"publickey,password"

/*
 * Append the base64 encoding of blob[0..blob_len) at *ptr (despite the name,
 * this is base64 via b64_ntop(), not uuencode) and advance *ptr past the
 * encoded text. room is the number of bytes writable at *ptr.
 *
 * If the encoding does not fit, *ptr is moved to the last writable byte and
 * that byte is set to NUL. The caller is therefore left with exactly one byte
 * of room, which sshd_do_userauth() reports as a buffer overflow. With
 * room == 0 nothing is written and *ptr is left unchanged.
 */
static void uudump(char **ptr, unsigned room, const __u8 *blob, unsigned blob_len)
{
	char *out = *ptr;
	int written = b64_ntop(blob, blob_len, out, room);

	if (__likely(written >= 0)) {
		*ptr = out + written;
		return;
	}
	if (__unlikely(!room))
		return;
	out += room - 1;
	*out = 0;
	*ptr = out;
}

/*
 * Release the buffer built by sshd_do_userauth() after overwriting it, since
 * it may contain a plaintext password. The stores go through a volatile
 * pointer because a plain memset() right before free() can be removed by the
 * optimizer as a dead store.
 */
static void free_auth_data(char *auth_data, size_t len)
{
	volatile char *v = auth_data;
	while (len--) *v++ = 0;
	free(auth_data);
}

/*
 * Return nonzero if a client-supplied string must not be copied into the
 * TTYSRVR auth message.
 *
 * - '\n': the message is built from '\n'-terminated lines, so an embedded
 *   newline would let the client append lines of its own (an extra
 *   credential line, an early "DONE"). This assumes TTYSRVR splits on '\n',
 *   as the format suggests.
 * - '\0': the string is copied with "%.*s", which stops at a NUL, so the
 *   value would be silently truncated.
 */
static int auth_str_has_bad_chars(const __u8 *s, unsigned len)
{
	return memchr(s, '\n', len) != NULL || memchr(s, 0, len) != NULL;
}

/*
 * Process one SSH_MSG_USERAUTH_REQUEST (RFC 4252) and hand the credentials to
 * the TTY server. Installed as c->in_upcall and re-entered after every
 * USERAUTH_FAILURE / USERAUTH_PK_OK reply. Packets of any other type are
 * passed to ssh_unknown_packet().
 *
 * Methods:
 *   "none"      - forwarded to TTYSRVR with no credential line.
 *   "password"  - forwarded as "PASSWORD <password>". Password-change
 *                 requests (boolean TRUE) are answered with USERAUTH_FAILURE,
 *                 because changing passwords is not implemented and, per
 *                 RFC 4252 section 8, SUCCESS would mean "password changed".
 *   "publickey" - without a signature: USERAUTH_PK_OK if the algorithm is in
 *                 public_types[]. With a signature: forwarded as
 *                 "PUBLIC_KEY <alg> <b64 key> <b64 signed data> <b64 sig>".
 *   other       - USERAUTH_FAILURE.
 *
 * The TTYSRVR message consists of the tty name (TTYSTR_LEN bytes), a path
 * line built from default_user and the user name (with '.' mapped to '/'),
 * the optional credential line and "DONE\n". Once IOCTL_TTYSRVR_MAKETTY
 * succeeds, IOCTL_TTYSRVR_GETAUTH is issued asynchronously and
 * ssh_got_auth_response() handles its completion.
 */
static void sshd_do_userauth(struct ssh_connection *c)
{
	unsigned sp;
	unsigned sp_before_signature;
	__u8 cmd;
	int m;
	const __u8 *user, *service, *method, *password, *algorithm, *key, *signature;
	unsigned user_len, service_len, method_len, password_len, algorithm_len, key_len, signature_len;
	char *auth_data, *p;
	IOCTLRQ io;

	/* zap warnings: */
	sp_before_signature = 0;
	password = NULL, password_len = 0;
	algorithm = NULL, algorithm_len = 0;
	key = NULL, key_len = 0;
	signature = NULL, signature_len = 0;

	again:
	sp = 0;
	if (__unlikely(ssh_get_chr_from_connection(c, &sp, &cmd))) return;
	if (__unlikely(cmd != SSH_MSG_USERAUTH_REQUEST)) {
		if (ssh_unknown_packet(c, 1)) return;
		goto again;
	}
	if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &user, &user_len))) return;
	if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &service, &service_len))) return;
	if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &method, &method_len))) return;

	if (__unlikely(!user_len)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_ILLEGAL_USER_NAME, "EMPTY USER NAME");
		return;
	}
	/* The user name is copied verbatim into the line-based TTYSRVR message. */
	if (__unlikely(auth_str_has_bad_chars(user, user_len))) {
		ssh_send_disconnect(c, SSH_DISCONNECT_ILLEGAL_USER_NAME, "INVALID CHARACTER IN USER NAME");
		return;
	}
	if (method_len == 4 && !memcmp(method, "none", 4)) {
		m = AUTH_NONE;
		goto do_auth;
	}
	if (method_len == 8 && !memcmp(method, "password", 8)) {
		__u8 chpass;
		if (__unlikely(ssh_get_chr_from_connection(c, &sp, &chpass))) return;
		if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &password, &password_len))) return;
		/*
		 * A TRUE flag is a password-change request (a new password follows).
		 * That is not supported, so refuse it rather than report SUCCESS.
		 * Passwords that cannot be embedded safely in the message are refused
		 * as well.
		 */
		if (__unlikely(chpass) || __unlikely(auth_str_has_bad_chars(password, password_len))) goto fail_auth;
		m = AUTH_PASSWORD;
		goto do_auth;
	}
	if (method_len == 9 && !memcmp(method, "publickey", 9)) {
		int i;
		__u8 real_auth;
		if (__unlikely(ssh_get_chr_from_connection(c, &sp, &real_auth))) return;
		if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &algorithm, &algorithm_len))) return;
		if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &key, &key_len))) return;
		/* The signed data ends right before the signature field. */
		sp_before_signature = sp;
		if (real_auth) {
			if (__unlikely(ssh_get_len_str_from_connection(c, &sp, &signature, &signature_len))) return;
		} else {
			signature = NULL, signature_len = 0;
		}
		for (i = 0; public_types[i]; i++) if (strlen(public_types[i]->name) == algorithm_len && !memcmp(public_types[i]->name, algorithm, algorithm_len)) {
			if (!real_auth) {
				/* Query only: tell the client this key would be acceptable. */
				ssh_got_data(c, &sp);
				ssh_auth_pk_ok(c, algorithm, algorithm_len, key, key_len);
				return;
			} else {
				m = AUTH_PUBLIC_KEY;
				goto do_auth;
			}
		}
		goto fail_auth;
	}

	fail_auth:
	/* Drop the rest of the request (including unparsed method-specific fields). */
	ssh_got_data(c, &sp);
	ssh_flush_stream(c);
	ssh_auth_failure(c);
	return;

	do_auth:
	if (__unlikely(!(auth_data = malloc(TTYSTR_LEN + AUTH_STR_LEN)))) {
		debug_fatal(c, "Can't alloc auth data");
		abort_ssh_connection(c);
		return;
	}
	memcpy(auth_data, c->tty_name, TTYSTR_LEN);
	p = auth_data + TTYSTR_LEN;
	/* Bytes still writable at p; every append below keeps this >= 1. */
#define room	(auth_data + TTYSTR_LEN + AUTH_STR_LEN - p)
	if (__likely(user[0] != '/') && __likely(user[0] != '.')) {
		/* Relative user name: prefix it with default_user, ensuring slashes around it. */
		if (__unlikely(default_user[0] != '/')) _snprintf(p, room, "/"), p = strchr(p, 0);
		_snprintf(p, room, "%s", default_user), p = strchr(p, 0);
		if (__likely(default_user[0]) && __unlikely(default_user[strlen(default_user) - 1] != '/')) _snprintf(p, room, "/"), p = strchr(p, 0);
	}
	_snprintf(p, room, "%.*s", (int)user_len, user);
	/* Map '.' to '/' in the copied user name (this also rules out ".." components). */
	for (; *p; p++) if (__unlikely(*p == '.')) *p = '/';
	if (__likely(user[user_len - 1] != '/')) _snprintf(p, room, "/"), p = strchr(p, 0);
	_snprintf(p, room, "\n"), p = strchr(p, 0);

	if (m == AUTH_PASSWORD) {
		_snprintf(p, room, "PASSWORD %.*s\n", (int)password_len, password), p = strchr(p, 0);
	} else if (m == AUTH_PUBLIC_KEY) {
		__u8 *b;
		unsigned bl;
		_snprintf(p, room, "PUBLIC_KEY %.*s ", (int)algorithm_len, algorithm); p = strchr(p, 0);
		uudump(&p, room, key, key_len);
		_snprintf(p, room, " "); p = strchr(p, 0);

		/*
		 * Data covered by the client's signature (RFC 4252 section 7): the
		 * session id as a length-prefixed string, then this request packet
		 * up to, but not including, the signature field.
		 */
		init_str(&b, &bl);
		/* SSH_OLD_SESSIONID */
		/* NOTE(review): b is not freed on these failure paths (unchanged);
		   confirm whether ssh_add_*_to_str releases it on error. */
		if (__unlikely(ssh_add_len_str_to_str(c, &b, &bl, c->session_id, c->session_id_len))) {
			free_auth_data(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
			return;
		}
		if (__unlikely(ssh_add_bytes_to_str(c, &b, &bl, c->in_stream + c->in_stream_start, sp_before_signature))) {
			free_auth_data(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
			return;
		}
		uudump(&p, room, b, bl);
		free(b);
		_snprintf(p, room, " "); p = strchr(p, 0);

		uudump(&p, room, signature, signature_len);
		_snprintf(p, room, "\n"); p = strchr(p, 0);
	}

	_snprintf(p, room, "DONE\n"), p = strchr(p, 0);
	/* One byte left means some append was truncated (or the buffer is exactly full). */
	if (__unlikely(room <= 1)) {
		free_auth_data(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
		debug_fatal(c, "Auth buffer overflow");
		ssh_send_disconnect(c, SSH_DISCONNECT_AUTH_CANCELLED_BY_USER, "AUTHENTICATION BUFFER OVERFLOW (TOO LONG USERNAME, PASSWORD OR KEY)");
		return;
	}
#undef room

	ssh_got_data(c, &sp);

	io.h = ttyh;
	io.ioctl = IOCTL_TTYSRVR_MAKETTY;
	io.param = 0;
	io.v.ptr = (unsigned long)auth_data;
	io.v.len = p - auth_data;
	io.v.vspace = &KERNEL$VIRTUAL;
	SYNC_IO(&io, KERNEL$IOCTL);
	free_auth_data(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
	if (__unlikely(io.status < 0)) {
		debug_fatal(c, "TTY server refused to make a tty: %s", strerror(-io.status));
		ssh_send_disconnect(c, SSH_DISCONNECT_CONNECTION_LOST, "TTYSRVR ERROR");
		return;
	}
	c->flags |= SSHD_HAS_TTY;

	/* Ask TTYSRVR for the verdict asynchronously; ssh_got_auth_response() completes it. */
	c->out_tty_upcall = ssh_got_auth_response;
	c->out_tty.ioctl = IOCTL_TTYSRVR_GETAUTH;
	c->out_tty.param = 0;
	c->out_tty.v.ptr = (unsigned long)c->tty_name;
	c->out_tty.v.len = TTYSTR_LEN;
	RAISE_SPL(SPL_BOTTOM);
	c->outstanding++;
	CALL_IORQ(&c->out_tty, KERNEL$IOCTL);
	LOWER_SPL(SPL_ZERO);
}

/*
 * Record a failed authentication attempt.
 *
 * On the MAX_AUTH_TRIES-th failure the client is disconnected with
 * NO_MORE_AUTH_METHODS_AVAILABLE. Otherwise SSH_MSG_USERAUTH_FAILURE is sent,
 * listing AUTH_METHODS with "partial success" = FALSE. sshd_do_userauth() is
 * then installed as c->in_upcall and called directly, so any request already
 * buffered is processed.
 */
static void ssh_auth_failure(struct ssh_connection *c)
{
	c->auth_count++;
	if (__unlikely(c->auth_count >= MAX_AUTH_TRIES)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, "TOO MANY AUTHENTICATION TRIES");
		return;
	}

	if (__unlikely(ssh_add_chr_to_connection(c, SSH_MSG_USERAUTH_FAILURE))
	 || __unlikely(ssh_add_len_str_to_connection(c, (__u8 *)AUTH_METHODS, sizeof(AUTH_METHODS) - 1))
	 || __unlikely(ssh_add_chr_to_connection(c, 0)))
		return;

	c->in_upcall = sshd_do_userauth;
	if (__likely(!ssh_send_packet(c, 0)))
		sshd_do_userauth(c);
}

/*
 * Answer a signature-less "publickey" query (RFC 4252 section 7) with
 * SSH_MSG_USERAUTH_PK_OK, echoing the algorithm name and key blob. Then
 * install sshd_do_userauth() as c->in_upcall and call it directly to handle
 * the next request. Returns early if building or sending the packet fails.
 */
static void ssh_auth_pk_ok(struct ssh_connection *c, const __u8 *algorithm, unsigned algorithm_len, const __u8 *key, unsigned key_len)
{
	if (__unlikely(ssh_add_chr_to_connection(c, SSH_MSG_USERAUTH_PK_OK))
	 || __unlikely(ssh_add_len_str_to_connection(c, algorithm, algorithm_len))
	 || __unlikely(ssh_add_len_str_to_connection(c, key, key_len)))
		return;

	c->in_upcall = sshd_do_userauth;
	if (__likely(!ssh_send_packet(c, 0)))
		sshd_do_userauth(c);
}

/*
 * Completion upcall for IOCTL_TTYSRVR_GETAUTH (installed by sshd_do_userauth()).
 *
 * Status 0: the credentials were accepted. SSH_MSG_USERAUTH_SUCCESS is sent
 * and the connection enters sshd_main_loop().
 * Any other status: the tty created for this attempt is closed synchronously
 * with IOCTL_TTYSRVR_CLOSETTY. If TTYSRVR refuses, the client is
 * disconnected. Otherwise SSHD_HAS_TTY is cleared and ssh_auth_failure()
 * reports the failure.
 */
static void ssh_got_auth_response(struct ssh_connection *c)
{
	IOCTLRQ close_rq;

	if (__likely(!c->out_tty.status)) {
		if (__unlikely(ssh_add_chr_to_connection(c, SSH_MSG_USERAUTH_SUCCESS))
		 || __unlikely(ssh_send_packet(c, 0)))
			return;
		sshd_main_loop(c);
		return;
	}

	/* Rejected: tear down this attempt's tty before the client retries. */
	close_rq.h = ttyh;
	close_rq.ioctl = IOCTL_TTYSRVR_CLOSETTY;
	close_rq.param = 0;
	close_rq.v.ptr = (unsigned long)c->tty_name;
	close_rq.v.len = TTYSTR_LEN;
	close_rq.v.vspace = &KERNEL$VIRTUAL;
	SYNC_IO(&close_rq, KERNEL$IOCTL);
	if (__unlikely(close_rq.status < 0)) {
		debug_fatal(c, "IOCTL_TTYSRVR_CLOSETTY(%s): %ld", c->tty_name, close_rq.status);
		ssh_send_disconnect(c, SSH_DISCONNECT_CONNECTION_LOST, "TTYSRVR CAN'T CLOSE TTY");
		return;
	}
	c->flags &= ~SSHD_HAS_TTY;
	ssh_auth_failure(c);
}
