#include <STDLIB.H>
#include <STRING.H>
#include <ENDIAN.H>
#include <ARPA/INET.H>
#include <SPAD/DEV.H>
#include <SPAD/SYNC.H>

#include "SSHD.H"
#include "STRUCT.H"
#include "OPS.H"

/*
 * Number of failed authentication attempts at which ssh_auth_failure()
 * disconnects the client instead of sending another failure reply.
 */
#define MAX_AUTH_TRIES		20

/* Forward declarations of the file-local authentication steps. */
static void sshd_do_userauth(struct ssh_connection *c);
static void ssh_got_auth_response(struct ssh_connection *c);
static void ssh_auth_failure(struct ssh_connection *c);

/*
 * Handle the service request that precedes user authentication.
 *
 * Packets other than SSH_MSG_SERVICE_REQUEST are passed to
 * ssh_unknown_packet() and the next one is examined. The only service
 * accepted is "ssh-userauth"; any other name ends in a disconnect with
 * SSH_DISCONNECT_SERVICE_NOT_AVAILABLE. On acceptance the service name is
 * echoed back in SSH_MSG_SERVICE_ACCEPT, c->tty_name is set to
 * ".<peer address>.<peer port>", the failure counter is cleared and
 * sshd_do_userauth() takes over as c->in_upcall.
 *
 * A non-zero result from any ssh_get_*, ssh_add_* or ssh_send_packet call
 * makes this function return at once (presumably the helper has already
 * dealt with the condition, e.g. incomplete input - confirm in the helpers).
 */
void sshd_userauth(struct ssh_connection *c)
{
	static const char wanted[] = "ssh-userauth";
	unsigned pos;
	__u8 type;
	char *name;
	unsigned name_len;

	/* Skip over anything that is not a service request. */
	for (;;) {
		pos = 0;
		if (__unlikely(ssh_get_byte(c, &pos, &type))) return;
		if (__likely(type == SSH_MSG_SERVICE_REQUEST)) break;
		if (ssh_unknown_packet(c, 1)) return;
	}

	if (__unlikely(ssh_get_string(c, &pos, &name, &name_len))) return;
	ssh_got_data(c, &pos);

	if (__unlikely(name_len != sizeof wanted - 1) || __unlikely(memcmp(name, wanted, sizeof wanted - 1) != 0)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_SERVICE_NOT_AVAILABLE, "SERVICE NOT SUPPORTED");
		return;
	}

	/* Acknowledge the service by echoing its name. */
	if (__unlikely(ssh_add_byte(c, SSH_MSG_SERVICE_ACCEPT))) return;
	if (__unlikely(ssh_add_string(c, name, name_len))) return;
	if (__unlikely(ssh_send_packet(c, 0))) return;

	/* The tty name is derived from the peer's address and port. */
	_snprintf(c->tty_name, TTYSTR_LEN, ".%s.%d", inet_ntoa(c->sin.sin_addr), ntohs(c->sin.sin_port));

	c->auth_count = 0;

	/* Later input goes to the authentication handler; run it once now. */
	c->in_upcall = sshd_do_userauth;
	sshd_do_userauth(c);
}

/* Authentication method codes held in the local `m` of sshd_do_userauth(). */
#define AUTH_NONE	0
#define AUTH_PASSWORD	1

/*
 * Size of the request text that sshd_do_userauth() builds after the
 * TTYSTR_LEN-byte tty name in the buffer passed to IOCTL_TTYSRVR_MAKETTY.
 */
#define AUTH_STR_LEN	2048
	/* same as TTY_BUFFER_SIZE in TTYSRVR.C */

/* Method list sent to the client in SSH_MSG_USERAUTH_FAILURE. */
#define AUTH_METHODS	"password"

/*
 * Return non-zero when the len-byte string s contains a NUL or a line feed.
 *
 * The request built by sshd_do_userauth() is a sequence of '\n'-terminated
 * lines formatted with "%.*s": a line feed inside a client-supplied field
 * would start a new line of the client's choosing, and a NUL would cut the
 * field short.
 */
static int sshd_auth_breaks_framing(const char *s, unsigned len)
{
	return memchr(s, '\0', len) != NULL || memchr(s, '\n', len) != NULL;
}

/*
 * Overwrite len bytes at buf with zeros. The stores go through a volatile
 * pointer so that they are not dropped as dead when the buffer is freed
 * immediately afterwards.
 */
static void sshd_auth_wipe(char *buf, unsigned len)
{
	volatile char *v = buf;
	while (len--) *v++ = 0;
}

/*
 * Process one SSH_MSG_USERAUTH_REQUEST.
 *
 * Packets of any other type are passed to ssh_unknown_packet() and the next
 * one is examined. The methods "none" and "password" are turned into a
 * request for the tty server; any other method is answered through
 * ssh_auth_failure().
 *
 * The request buffer consists of the TTYSTR_LEN bytes of c->tty_name
 * followed by text lines:
 *	<user path>/\n
 *	PASSWORD <password>\n	(password method only)
 *	DONE\n
 * The user path is the user name with every '.' replaced by '/'; unless the
 * name itself starts with '/' or '.', default_user is put in front of it.
 * The buffer is submitted with IOCTL_TTYSRVR_MAKETTY, after which an
 * IOCTL_TTYSRVR_GETAUTH request is posted on c->out_tty with
 * ssh_got_auth_response() as its upcall.
 *
 * User names and passwords come from the network. A user name containing a
 * NUL or line feed causes a disconnect; a password containing one is counted
 * as a failed attempt. The buffer holding the password is wiped before it is
 * freed.
 *
 * A non-zero result from any ssh_get_* call makes this function return at
 * once (presumably the input is incomplete or the helper already handled the
 * error - confirm in the helpers).
 *
 * NOTE(review): the requested service name is parsed but never checked.
 */
static void sshd_do_userauth(struct ssh_connection *c)
{
	unsigned sp;
	__u8 cmd;
	int m;
	char *user, *service, *method, *password = NULL;
	unsigned user_len, service_len, method_len, password_len = 0;
	char *auth_data, *p;
	IOCTLRQ io;
	again:
	sp = 0;
	if (__unlikely(ssh_get_byte(c, &sp, &cmd))) return;
	if (__unlikely(cmd != SSH_MSG_USERAUTH_REQUEST)) {
		if (ssh_unknown_packet(c, 1)) return;
		goto again;
	}
	if (__unlikely(ssh_get_string(c, &sp, &user, &user_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &service, &service_len))) return;
	if (__unlikely(ssh_get_string(c, &sp, &method, &method_len))) return;
	/*__debug_printf("userauth: user %.*s, service %.*s, method %.*s\n", user_len, user, service_len, service, method_len, method);*/

	if (__unlikely(!user_len)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_ILLEGAL_USER_NAME, "EMPTY USER NAME");
		return;
	}
	if (__unlikely(sshd_auth_breaks_framing(user, user_len))) {
		ssh_send_disconnect(c, SSH_DISCONNECT_ILLEGAL_USER_NAME, "INVALID CHARACTER IN USER NAME");
		return;
	}
	if (method_len == 4 && !memcmp(method, "none", 4)) {
		m = AUTH_NONE;
		goto do_auth;
	}
	if (method_len == 8 && !memcmp(method, "password", 8)) {
		/* The "change password" flag is read to advance sp; its value is ignored. */
		__u8 chpass;
		if (__unlikely(ssh_get_byte(c, &sp, &chpass))) return;
		if (__unlikely(ssh_get_string(c, &sp, &password, &password_len))) return;
		/* Such a password cannot be passed on intact: count it as a failed attempt. */
		if (__unlikely(sshd_auth_breaks_framing(password, password_len))) goto reject;
		m = AUTH_PASSWORD;
		goto do_auth;
	}
	/* Unsupported method (or rejected password): report failure to the client. */
	reject:
	ssh_got_data(c, &sp);
	ssh_flush_stream(c);
	ssh_auth_failure(c);
	return;
	do_auth:
	ssh_got_data(c, &sp);
	if (__unlikely(!(auth_data = malloc(TTYSTR_LEN + AUTH_STR_LEN)))) {
		debug_fatal(c, "Can't alloc auth data");
		abort_ssh_connection(c);
		return;
	}
	memcpy(auth_data, c->tty_name, TTYSTR_LEN);
	p = auth_data + TTYSTR_LEN;
	/* Bytes left between p and the end of the buffer. */
#define room	(auth_data + TTYSTR_LEN + AUTH_STR_LEN - p)
	if (__likely(user[0] != '/') && __likely(user[0] != '.')) {
		/* Relative user name: prefix it with "/<default_user>/". */
		if (__unlikely(default_user[0] != '/')) _snprintf(p, room, "/"), p += strlen(p);
		_snprintf(p, room, "%s", default_user), p += strlen(p);
		if (__likely(default_user[0]) && __unlikely(default_user[strlen(default_user) - 1] != '/')) _snprintf(p, room, "/"), p += strlen(p);
	}
	_snprintf(p, room, "%.*s", (int)user_len, user);
	/* Dots in the user name become path separators; the loop leaves p at the terminating NUL. */
	for (; *p; p++) if (__unlikely(*p == '.')) *p = '/';
	if (__likely(user[user_len - 1] != '/')) _snprintf(p, room, "/"), p += strlen(p);
	_snprintf(p, room, "\n"), p += strlen(p);
	if (m == AUTH_PASSWORD) {
		_snprintf(p, room, "PASSWORD %.*s\n", (int)password_len, password), p += strlen(p);
	}
	_snprintf(p, room, "DONE\n"), p += strlen(p);
	/*
	 * Every write above is limited to `room` and p advances by strlen(), so
	 * p cannot pass the end of the buffer (this relies on _snprintf always
	 * NUL-terminating). Only the terminator's slot being left means the text
	 * filled the buffer and may have been truncated.
	 */
	if (__unlikely(room <= 1)) {
		sshd_auth_wipe(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
		__slow_free(auth_data);
		debug_fatal(c, "Auth buffer overflow");
		ssh_send_disconnect(c, SSH_DISCONNECT_AUTH_CANCELLED_BY_USER, "AUTHENTICATION BUFFER OVERFLOW (TOO LONG USERNAME OR PASSWORD)");
		return;
	}
#undef room

	/* Hand the tty name and the request text to the tty server. */
	io.h = ttyh;
	io.ioctl = IOCTL_TTYSRVR_MAKETTY;
	io.param = 0;
	io.v.ptr = (unsigned long)auth_data;
	io.v.len = p - auth_data;
	io.v.vspace = &KERNEL$VIRTUAL;
	SYNC_IO(&io, KERNEL$IOCTL);
	/* The buffer may hold the plaintext password: clear it before freeing. */
	sshd_auth_wipe(auth_data, TTYSTR_LEN + AUTH_STR_LEN);
	free(auth_data);
	if (__unlikely(io.status < 0)) {
		debug_fatal(c, "TTY server refused to make a tty: %s", strerror(-io.status));
		ssh_send_disconnect(c, SSH_DISCONNECT_CONNECTION_LOST, "TTYSRVR ERROR");
		return;
	}
	c->flags |= SSHD_HAS_TTY;

	/* Ask for the authentication result; ssh_got_auth_response() receives it. */
	c->out_tty_upcall = ssh_got_auth_response;
	c->out_tty.ioctl = IOCTL_TTYSRVR_GETAUTH;
	c->out_tty.param = 0;
	c->out_tty.v.ptr = (unsigned long)c->tty_name;
	c->out_tty.v.len = TTYSTR_LEN;
	/*
	 * NOTE(review): the outstanding-request count is bumped and the request
	 * posted with the SPL raised, presumably so the completion cannot run in
	 * between - confirm.
	 */
	RAISE_SPL(SPL_BOTTOM);
	c->outstanding++;
	CALL_IORQ(&c->out_tty, KERNEL$IOCTL);
	LOWER_SPL(SPL_ZERO);
}

/*
 * Account for one failed authentication attempt and answer the client.
 *
 * The attempt is counted in c->auth_count; once the count reaches
 * MAX_AUTH_TRIES the connection is disconnected. Otherwise an
 * SSH_MSG_USERAUTH_FAILURE packet carrying AUTH_METHODS and a zero byte is
 * sent, sshd_do_userauth() is installed as c->in_upcall, and it is called
 * directly to process the next request.
 */
static void ssh_auth_failure(struct ssh_connection *c)
{
	c->auth_count++;
	if (__unlikely(c->auth_count >= MAX_AUTH_TRIES)) {
		ssh_send_disconnect(c, SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, "TOO MANY AUTHENTICATION TRIES");
		return;
	}

	/* Build the failure packet; stop at the first helper that reports an error. */
	if (__unlikely(ssh_add_byte(c, SSH_MSG_USERAUTH_FAILURE)) ||
	    __unlikely(ssh_add_string(c, AUTH_METHODS, strlen(AUTH_METHODS))) ||
	    __unlikely(ssh_add_byte(c, 0)))
		return;

	/* The handler is installed before the packet goes out. */
	c->in_upcall = sshd_do_userauth;
	if (__likely(!ssh_send_packet(c, 0)))
		sshd_do_userauth(c);
}

/*
 * Upcall for the IOCTL_TTYSRVR_GETAUTH request posted by sshd_do_userauth().
 *
 * A zero c->out_tty.status means the user was authenticated:
 * SSH_MSG_USERAUTH_SUCCESS is sent and sshd_main_loop() is entered.
 * Any other status is a failure: the tty is closed with
 * IOCTL_TTYSRVR_CLOSETTY, SSHD_HAS_TTY is cleared and the failure is
 * reported through ssh_auth_failure(). If the close itself fails the
 * connection is disconnected and SSHD_HAS_TTY is left set.
 */
static void ssh_got_auth_response(struct ssh_connection *c)
{
	IOCTLRQ close_rq;

	if (__likely(!c->out_tty.status)) {
		/* Authenticated. */
		if (__unlikely(ssh_add_byte(c, SSH_MSG_USERAUTH_SUCCESS))) return;
		if (__unlikely(ssh_send_packet(c, 0))) return;
		sshd_main_loop(c);
		return;
	}

	/* Rejected: close the tty that was created for this attempt. */
	close_rq.h = ttyh;
	close_rq.ioctl = IOCTL_TTYSRVR_CLOSETTY;
	close_rq.param = 0;
	close_rq.v.ptr = (unsigned long)c->tty_name;
	close_rq.v.len = TTYSTR_LEN;
	close_rq.v.vspace = &KERNEL$VIRTUAL;
	SYNC_IO(&close_rq, KERNEL$IOCTL);
	if (__unlikely(close_rq.status < 0)) {
		debug_fatal(c, "IOCTL_TTYSRVR_CLOSETTY(%s): %ld", c->tty_name, close_rq.status);
		ssh_send_disconnect(c, SSH_DISCONNECT_CONNECTION_LOST, "TTYSRVR CAN'T CLOSE TTY");
		return;
	}

	c->flags &= ~SSHD_HAS_TTY;
	ssh_auth_failure(c);
}
