/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

#include "test.h"
#include "testnet.h"
#include <librecast/key.h>
#include <librecast/net.h>
#include <pthread.h>
#include <semaphore.h>

#define TEST_NAME "restricted channels"

#if defined(HAVE_LIBLCRQ) && (HAVE_LIBSODIUM)
#define WAITS 4 /* timeout seconds */
#define TOKEN_SECONDS 60
#define TEST_SIZE 1024
/* parenthesized so the macro expands safely inside any expression */
#define SPEED_LIMIT (1024 * 1024 * 32) /* 32Mbps */

/* indices into the per-test socket/channel/thread arrays */
enum {
	RECV,
	SEND
};

static sem_t sem;       /* posted by the receiver thread when it has finished */
static sem_t sem_recv;  /* posted by the receiver thread when it is ready to recv */
static uint8_t *senderpubkey;   /* sender (token bearer) public key */
static uint8_t *signerpubkey;   /* token signer public key, added to the receiver keyring */
static lc_token_t *sendertoken; /* set by create_token() to point at `token` */
static char send_buf[TEST_SIZE]; /* payload sent with a valid token */
static char fake_buf[TEST_SIZE]; /* payload sent without a token */
static lc_token_t token;
static lc_keypair_t signerkey; /* needs to be global so we can create bad tokens */
static lc_key_t ssk;     /* sender secret signing key */
static lc_key_t bad_ssk; /* unauthorized (signer) secret key; NOTE(review): unused by visible tests */
static unsigned char *enckey; /* symmetric encryption key */
static unsigned int ifidx;    /* multicast interface index */
static int multi;             /* 0 => lc_channel_recv(), 1 => lc_socket_multi_recv() */
static uint8_t capbits = 42;  /* token capability bits, also used by the receiver filter */

/*
 * Receiver thread.
 *
 * Installs a filter on the channel that trusts tokens signed by signerpubkey
 * with the capbits capability bits. It then tells the sender it is ready and
 * receives exactly once. If the filter works, the token-less fake packet is
 * dropped and the data received matches send_buf. Always posts `sem` on exit
 * so the control thread stops waiting.
 *
 * arg: the receiving lc_channel_t.
 */
static void *thread_recv(void *arg)
{
	lc_channel_t *chan = arg;
	lc_channel_t *dst = NULL;   /* only written by lc_socket_multi_recv() */
	lc_keyring_t keyring = {0};
	lc_filter_t filter = {0};   /* zero any fields not set explicitly below */
	char buf[TEST_SIZE] = {0};  /* BUFSIZ is only guaranteed to be >= 256 */
	ssize_t byt;
	int rc;

	/* set filter on channel */
	rc = lc_keyring_init(&keyring, 1);
	if (!test_assert(rc == 0, "lc_keyring_init()")) goto finished;
	rc = lc_keyring_add(&keyring, signerpubkey);
	if (!test_assert(rc == 0, "lc_keyring_add()")) goto free_keyring;
	filter.keyring = &keyring;
	filter.capbits = capbits;
	lc_channel_filter_set(chan, &filter);

	/* We recv once and exit. If the filter is working, we only recv the
	 * good (last) packet. If not, we have garbage */
	sem_post(&sem_recv); /* tell sender we're ready */
	if (multi) byt = lc_socket_multi_recv(lc_channel_socket(chan), buf, TEST_SIZE, 0, &dst);
	else byt = lc_channel_recv(chan, buf, TEST_SIZE, 0);
	if (byt == -1) perror("lc_channel_recv/lc_socket_multi_recv");
	test_assert(byt == TEST_SIZE, "lc_channel_recv() returned %zi", byt);
	test_assert(memcmp(send_buf, buf, TEST_SIZE) == 0, "received data matches");
	if (multi) test_assert(dst == chan, "destination channel matches channel pointer");

free_keyring:
	lc_keyring_free(&keyring);
finished:
	sem_post(&sem); /* tell ctrl thread we are finished */
	return NULL;
}

/*
 * Sender thread.
 *
 * Signs with the sender's secret key (ssk) and first sends fake_buf WITHOUT a
 * token, which the receiver's filter is expected to drop. It then attaches
 * the token created by create_token(). After waiting until the receiver is
 * ready, it sends send_buf, which the receiver is expected to accept.
 *
 * arg: the sending lc_channel_t.
 */
static void *thread_send(void *arg)
{
	lc_channel_t *chan = arg;
	ssize_t byt;

	lc_channel_setkey(chan, &ssk, LC_CODE_SIGN);

	/* send fake data without token */
	byt = lc_channel_send(chan, fake_buf, sizeof fake_buf, 0);
	test_assert(byt == TEST_SIZE, "lc_channel_send() returned %zi", byt);

	/* set channel token (signing key re-applied: the sender's own key, ssk) */
	lc_channel_setkey(chan, &ssk, LC_CODE_SIGN);
	lc_channel_token_set(chan, &token);

	sem_wait(&sem_recv); /* wait for receiver to be ready */
	byt = lc_channel_send(chan, send_buf, sizeof send_buf, 0);
	/* %zi requires a ssize_t argument; TEST_SIZE is a plain int constant */
	test_assert(byt == (ssize_t)TEST_SIZE, "lc_channel_send() returned %zi / %zi",
			byt, (ssize_t)TEST_SIZE);

	return NULL;
}

/*
 * Create the channel token used by the sender and verify its contents.
 *
 * The token is issued by `signer` to the bearer key senderpubkey for channel
 * `chan`. It carries the capbits capability bits and a lifetime of
 * TOKEN_SECONDS. It is written to the global `token` and published via
 * `sendertoken`.
 *
 * Returns -1 if any check fails, otherwise the current test_status.
 */
static int create_token(lc_keypair_t *signer, lc_channel_t *chan)
{
	struct timespec expires;
	int rc;

	/* lower bound for token.expires: taken before the token is created */
	rc = clock_gettime(CLOCK_REALTIME, &expires);
	if (!test_assert(rc == 0, "clock_gettime()"))
		return -1;
	expires.tv_sec += TOKEN_SECONDS;
	rc = lc_token_new(&token, signer, senderpubkey, chan, capbits, TOKEN_SECONDS);

	/* verify token */
	if (!test_assert(rc == 0, "lc_token_new() - return value = %i", rc))
		return -1;
	if (!test_assert(token.version == 0xff, "token: version set"))
		return -1;
	if (!test_assert(token.capbits == capbits, "token: capbits set"))
		return -1;
	if (!test_assert(token.expires >= (uint64_t)expires.tv_sec, "token: expires set"))
		return -1;
	rc = memcmp(token.channel, lc_channel_get_hash(chan), sizeof token.channel);
	if (!test_assert(rc == 0, "token: channel set"))
		return -1;
	rc = memcmp(token.signkey, signer->pk, sizeof token.signkey);
	if (!test_assert(rc == 0, "token: signerkey set"))
		return -1;
	rc = memcmp(token.bearkey, senderpubkey, sizeof token.bearkey);
	if (!test_assert(rc == 0, "token: bearer key set"))
		return -1;
	/* NOTE(review): assumes the signed portion of lc_token_t is its first
	 * 64 bytes - confirm against the lc_token_t layout */
	rc = crypto_sign_verify_detached(token.sig, (unsigned char *)&token, 64, signer->pk);
	if (!test_assert(rc == 0, "token: signature OK"))
		return -1;

	/* publish token */
	sendertoken = &token;
	return test_status;
}

/*
 * Generate the keys used by one test run.
 *
 * signer: receives a new signing keypair. Its public key is published in
 *         signerpubkey and its secret key is stored in bad_ssk.
 * sender: receives a new signing keypair. Its public key is published in
 *         senderpubkey and its secret key is stored in ssk.
 * If enckey has already been allocated, it is filled with a fresh
 * crypto_secretbox key.
 *
 * Returns -1 if a keypair cannot be created (nothing is published from a
 * failed generation), otherwise the current test_status.
 */
static int generate_keys(lc_keypair_t *signer, lc_keypair_t *sender)
{
	int rc;

	/* first, create signing keypair */
	rc = lc_keypair_new(signer, LC_KEY_SIG);
	if (!test_assert(rc == 0, "lc_keypair_new() - create signing keypair"))
		return -1;
	signerpubkey = signer->pk; /* publish signer public key */

	/* create sender keypair */
	rc = lc_keypair_new(sender, LC_KEY_SIG);
	if (!test_assert(rc == 0, "lc_keypair_new() - create sender keypair"))
		return -1;
	senderpubkey = sender->pk; /* publish sender public key */

	/* set secret signing key (sender) and bad (unauthorized) key */
	ssk.key = sender->sk;
	ssk.keylen = crypto_sign_SECRETKEYBYTES;
	bad_ssk.key = signer->sk;
	bad_ssk.keylen = crypto_sign_SECRETKEYBYTES;

	/* generate symmetric encryption key */
	if (enckey) crypto_secretbox_keygen(enckey);

	return test_status;
}

/*
 * Log a banner describing the test case about to run. It reports the socket
 * type, whether symmetric encryption is enabled, which FEC encodings are
 * enabled, and which receive API the receiver thread will call.
 */
static void print_test_encodings(lc_socktype_t socktype, int encoding)
{
	const int pair = (socktype & LC_SOCK_PAIR) != 0;
	const int symm = (encoding & LC_CODE_SYMM) != 0;
	const int raptorq = (encoding & LC_CODE_FEC_RQ) != 0;
	const int oti = (encoding & LC_CODE_FEC_OTI) != 0;

	test_log("--------------------------------------------------------------------------------\n");

	test_log("socket type: ");
	if (!pair) {
		test_log("IPv6\n");
	} else {
		test_log("socketpair\n");
	}

	test_log("encryption: ");
	if (!symm) {
		test_log("none\n");
	} else {
		test_log("symmetric encryption\n");
	}

	test_log("encodings:");
	if (raptorq) {
		test_log(" RaptorQ");
	}
	if (oti) {
		test_log(" OTI");
	}
	test_log("\n");

	test_log("receiver API call: ");
	if (!multi) {
		test_log("lc_channel_recv()\n");
	} else {
		test_log("lc_socket_multi_recv()\n");
	}

	test_log("--------------------------------------------------------------------------------\n");
}

/*
 * Run one restricted-channel test case.
 *
 * Setup:
 * - creates a context and a send/receive socket pair: two IPv6 sockets bound
 *   to ifidx, or an lc_socketpair();
 * - binds a channel to each socket and applies `encoding`;
 * - generates keys and a token.
 *
 * It then runs thread_recv/thread_send and waits up to WAITS seconds for the
 * receiver to finish.
 *
 * socktype: LC_SOCK_IN6 or LC_SOCK_PAIR (other values run nothing).
 * encoding: LC_CODE_* flags applied to both channels.
 *
 * Returns test_status.
 */
static int test_restricted(lc_socktype_t socktype, int encoding)
{
	void *(*thread_f[2])(void *) = { &thread_recv, &thread_send };
	pthread_t tid[2];
	lc_keypair_t senderkey;
	lc_ctx_t *lctx;
	lc_socket_t *sock[2];
	lc_channel_t *chan[2];
	struct timespec ts;
	int threads = 0; /* number of threads successfully started */
	int rc;

	print_test_encodings(socktype, encoding);

	/* prepare context, sockets, channels */
	lctx = lc_ctx_new();
	if (!test_assert(lctx != NULL, "lc_ctx_new()")) return test_status;
	if (socktype == LC_SOCK_IN6) {
		for (int i = 0; i < 2; i++) {
			sock[i] = lc_socket_new(lctx);
			if (!test_assert(sock[i] != NULL, "%i: lc_socket_new()", i)) goto free_lctx;
			rc = lc_socket_bind(sock[i], ifidx);
			if (rc == -1) perror("lc_socket_bind");
			if (!test_assert(rc == 0, "%i: lc_socket_bind() ifx = %u", i, ifidx)) goto free_lctx;
		}
		/* presumably enables multicast loopback so the local receiver
		 * sees what the sender transmits - confirm in librecast docs */
		rc = lc_socket_loop(sock[SEND], 1);
		if (!test_assert(rc == 0, "lc_socket_loop()")) goto free_lctx;
	}
	else if (socktype == LC_SOCK_PAIR) {
		rc = lc_socketpair(lctx, sock);
		if (!test_assert(rc == 0, "lc_socketpair()")) goto free_lctx;
	}
	else goto free_lctx;
	for (int i = 0; i < 2; i++) {
		chan[i] = lc_channel_new(lctx, TEST_NAME);
		if (!test_assert(chan[i] != NULL, "%i: lc_channel_new()", i)) goto free_lctx;
		test_log("chan[%i]: %p\n", i, (void *)chan[i]);
		rc = lc_channel_bind(sock[i], chan[i]);
		if (!test_assert(rc == 0, "lc_channel_bind()")) goto free_lctx;
		if (encoding) lc_channel_coding_set(chan[i], encoding);
	}
	rc = lc_channel_join(chan[RECV]);
	if (!test_assert(rc == 0, "lc_channel_join()")) goto free_lctx;
	if (encoding & LC_CODE_FEC_RQ) lc_channel_rq_overhead(chan[SEND], RQ_OVERHEAD * 2);
	lc_channel_ratelimit(chan[SEND], SPEED_LIMIT, 0);

	/* generate keys and tokens.
	 * The symmetric key buffer is allocated before generate_keys() (which
	 * fills it when enckey is non-NULL). It is handed to the channels only
	 * afterwards, in case lc_channel_set_sym_key() copies the key bytes
	 * rather than keeping the pointer. */
	if (encoding & LC_CODE_SYMM) {
		enckey = malloc(crypto_secretbox_KEYBYTES);
		if (!test_assert(enckey != NULL, "malloc() symmetric key")) goto free_lctx;
	}
	if (generate_keys(&signerkey, &senderkey)) goto free_lctx;
	if (encoding & LC_CODE_SYMM) {
		for (int i = 0; i < 2; i++) {
			lc_channel_set_sym_key(chan[i], enckey, crypto_secretbox_KEYBYTES);
		}
	}
	if (create_token(&signerkey, chan[0])) goto free_lctx;

	/* generate test data */
	arc4random_buf(send_buf, sizeof send_buf);
	arc4random_buf(fake_buf, sizeof fake_buf);

	/* start threads */
	rc = sem_init(&sem, 0, 0);
	if (!test_assert(rc == 0, "sem_init()")) goto free_lctx;
	rc = sem_init(&sem_recv, 0, 0);
	if (!test_assert(rc == 0, "sem_init()")) goto free_sem;
	for (int i = 0; i < 2; i++) {
		rc = pthread_create(&tid[i], NULL, thread_f[i], chan[i]);
		if (!test_assert(rc == 0, "%i: pthread_create", i)) goto join_threads;
		threads++;
	}

	/* wait at most WAITS seconds for the receiver to finish */
	if (!test_assert(!clock_gettime(CLOCK_REALTIME, &ts), "clock_gettime()")) goto join_threads;
	ts.tv_sec += WAITS;
	test_assert(!sem_timedwait(&sem, &ts), "timeout");

join_threads:
	for (int i = 0; i < threads; i++) {
		pthread_cancel(tid[i]);
		pthread_join(tid[i], NULL);
	}
	sem_destroy(&sem_recv);
free_sem:
	sem_destroy(&sem);
free_lctx:
	/* release the context (presumably with its channels) before the key
	 * buffer that was handed to those channels */
	lc_ctx_free(lctx);
	free(enckey);
	enckey = NULL;
	return test_status;
}
#endif

/*
 * Run the restricted channel test over every combination of:
 * - socket type: IPv6 multicast, socketpair;
 * - symmetric encryption: off, on;
 * - receive API: lc_channel_recv(), lc_socket_multi_recv();
 * - FEC encoding: none, RaptorQ, RaptorQ + OTI.
 * Stops at the first failing combination. Skips when built without liblcrq
 * or libsodium.
 */
int main(void)
{
	char name[] = TEST_NAME;
#if defined(HAVE_LIBLCRQ) && (HAVE_LIBSODIUM)
	/* list the values explicitly rather than incrementing enums across a
	 * range, which is only correct if they happen to be contiguous from 0 */
	static const lc_socktype_t socktypes[] = { LC_SOCK_IN6, LC_SOCK_PAIR };
	static const int crypto[] = { LC_CODE_NONE, LC_CODE_SYMM };
	static const int fec[] = {
		LC_CODE_NONE,
		LC_CODE_FEC_RQ,
		LC_CODE_FEC_RQ | LC_CODE_FEC_OTI,
	};

	test_name(name);
	test_require_net(TEST_NET_BASIC);

	ifidx = get_multicast_if();
	if (!test_assert(ifidx > 0, "get_multicast_if()")) return test_status;

	/* test different socket types and symmetric encryption with different encodings */
	for (size_t s = 0; s < sizeof socktypes / sizeof socktypes[0]; s++) {
		for (size_t c = 0; c < sizeof crypto / sizeof crypto[0]; c++) {
			/* multi = 0 => use lc_channel_recv()
			 * multi = 1 => use lc_socket_multi_recv() */
			for (multi = 0; multi < 2; multi++) {
				for (size_t f = 0; f < sizeof fec / sizeof fec[0]; f++) {
					if (test_restricted(socktypes[s], crypto[c] | fec[f]))
						return test_status;
				}
			}
		}
	}

	return test_status;
#else
	return test_skip("%s - requires liblcrq and libsodium", name);
#endif
}
