/*	$NetBSD: strlist.c,v 1.3 2023/08/11 07:05:39 mrg Exp $	*/

/*-
 * Copyright (c) 2021 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * This code is derived from software contributed to The NetBSD Foundation
 * by Jason R. Thorpe.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * strlist --
 *
 *	A set of routines for interacting with IEEE 1275 (OpenFirmware)
 *	style string lists.
 *
 *	An OpenFirmware string list is simply a buffer containing
 *	multiple NUL-terminated strings concatenated together.
 *
 *	So, for example, a string list consisting of the strings
 *	"foo", "bar", and "baz" would be represented in memory like:
 *
 *		foo\0bar\0baz\0
 */

#include <sys/types.h>

/*
 * Memory allocation wrappers to handle different environments.
 */
#if defined(_KERNEL)
#include <sys/kmem.h>
#include <sys/systm.h>

/*
 * Kernel flavour of the allocator wrapper: obtain size bytes from
 * kmem_zalloc() using the KM_SLEEP flag and hand the result back.
 */
static void *
strlist_alloc(size_t const size)
{
	void * const buf = kmem_zalloc(size, KM_SLEEP);

	return buf;
}

/*
 * Kernel flavour of the release wrapper: pass the buffer v and its
 * size straight through to kmem_free().
 * NOTE(review): size is presumably required to equal the size given to
 * strlist_alloc() -- confirm against kmem(9).
 */
static void
strlist_free(void * const v, size_t const size)
{
	kmem_free(v, size);
}
#elif defined(_STANDALONE)
#include <lib/libkern/libkern.h>
#include <lib/libsa/stand.h>

/*
 * Standalone flavour of the allocator wrapper: take size bytes from
 * alloc() and clear them.  Returns NULL if alloc() returned NULL.
 */
static void *
strlist_alloc(size_t const size)
{
	char * const buf = alloc(size);

	if (buf == NULL) {
		return NULL;
	}

	/* Callers get zero-filled memory, as in the other environments. */
	memset(buf, 0, size);
	return buf;
}

/*
 * Standalone flavour of the release wrapper: pass the buffer v and its
 * size straight through to dealloc().
 */
static void
strlist_free(void * const v, size_t const size)
{
	dealloc(v, size);
}
#else /* user-space */
#include <stdlib.h>
#include <string.h>

extern int pmatch(const char *, const char *, const char **);

/*
 * User-space flavour of the allocator wrapper: one zero-filled object
 * of size bytes from calloc(), or NULL if calloc() fails.
 */
static void *
strlist_alloc(size_t const size)
{
	void * const buf = calloc(1, size);

	return buf;
}

/*
 * User-space flavour of the release wrapper: hand the buffer v to
 * free().  The size argument exists only so the signature matches the
 * other environments; it is not used here (hence __unused).
 */
static void
strlist_free(void * const v, size_t const size __unused)
{
	free(v);
}
#endif

#include "strlist.h"

/*
 * strlist_next --
 *
 *	Enumerate a strlist.  *cursorp is the byte offset of the string
 *	to hand back; on success it is advanced past that string's
 *	terminating NUL so the following call yields the next entry.
 *	Begin an enumeration with *cursorp set to 0.
 *
 *	Returns NULL, leaving the cursor untouched, when sl is NULL,
 *	slsize is 0, cursorp is NULL, or the cursor has reached or
 *	passed slsize (no more strings).
 */
const char *
strlist_next(const char * const sl, size_t const slsize, size_t * const cursorp)
{
	if (!sl || !slsize || !cursorp) {
		return NULL;
	}

	size_t const offset = *cursorp;

	if (offset >= slsize) {
		return NULL;
	}

	const char * const entry = &sl[offset];

	/* Step over the string and its NUL terminator. */
	*cursorp = offset + strlen(entry) + 1;
	return entry;
}

/*
 * strlist_count --
 *
 *	Count the NUL-terminated strings packed into the slsize bytes
 *	starting at sl.  Empty strings are counted as well.  A NULL or
 *	zero-length list has a count of 0.
 */
unsigned int
strlist_count(const char *sl, size_t slsize)
{
	unsigned int nstrings = 0;

	if (sl == NULL) {
		return 0;
	}

	while (slsize != 0) {
		/* Consume one string together with its terminating NUL. */
		size_t const step = strlen(sl) + 1;

		nstrings++;
		sl += step;
		slsize -= step;
	}

	return nstrings;
}

/*
 * strlist_string --
 *
 *	Look up a string by position.  Index 0 is the first string in
 *	the list.  Returns a pointer into the list itself (not a copy),
 *	or NULL when the list is NULL or empty, or when idx is not less
 *	than the number of strings in the list.
 */
const char *
strlist_string(const char * sl, size_t slsize, unsigned int const idx)
{
	unsigned int pos = 0;

	if (sl == NULL) {
		return NULL;
	}

	while (slsize != 0) {
		size_t const step = strlen(sl) + 1;

		if (pos == idx) {
			return sl;
		}

		/* Not this one; move on to the following string. */
		pos++;
		slsize -= step;
		sl += step;
	}

	return NULL;
}

/*
 * match_strcmp --
 *
 *	Match predicate for strlist_match_internal(): true when s1 and
 *	s2 compare equal under strcmp().
 */
static bool
match_strcmp(const char * const s1, const char * const s2)
{
	int const diff = strcmp(s1, s2);

	return diff == 0;
}

#if !defined(_STANDALONE)
/*
 * match_pmatch --
 *
 *	Match predicate for strlist_match_internal(): s1 is the string
 *	and s2 the pattern handed to pmatch().  Only a pmatch() result
 *	of 2 is accepted as a match.
 *	NOTE(review): 2 is presumably the "exact match" result --
 *	confirm against pmatch(9).
 */
static bool
match_pmatch(const char * const s1, const char * const s2)
{
	int const result = pmatch(s1, s2, NULL);

	return result == 2;
}
#endif /* _STANDALONE */

/*
 * strlist_match_internal --
 *
 *	Common worker for the match / index routines.  Walks the list,
 *	calling match_fn(list-string, str) on each entry until one
 *	matches.
 *
 *	On a match, *indexp receives the zero-based index of the first
 *	matching string and true is returned; otherwise *indexp is left
 *	alone and false is returned.
 *
 *	If countp is non-NULL the walk always continues to the end of
 *	the list and *countp receives the total number of strings.  If
 *	countp is NULL the walk stops at the first match.
 *
 *	A NULL or zero-length list returns false without touching
 *	*indexp or *countp.
 */
static bool
strlist_match_internal(const char * const sl, size_t slsize,
    const char * const str, int * const indexp, unsigned int * const countp,
    bool (*match_fn)(const char *, const char *))
{
	if (sl == NULL || slsize == 0) {
		return false;
	}

	const char *cur = sl;
	int pos = 0;
	bool found = false;

	while (slsize != 0) {
		/*
		 * Once a match has been recorded, match_fn is no longer
		 * consulted; the remaining passes only count strings.
		 */
		if (!found && (*match_fn)(cur, str)) {
			*indexp = pos;
			found = true;
			if (countp == NULL) {
				/* Nobody wants the total; stop here. */
				break;
			}
		}

		size_t const step = strlen(cur) + 1;

		slsize -= step;
		cur += step;
		pos++;
	}

	if (countp != NULL) {
		*countp = pos;
	}

	return found;
}

/*
 * strlist_match --
 *
 *	Returns a weighted match value if the specified string appears
 *	in the strlist, or 0 if it does not.  The weight is the number
 *	of strings in the list minus the index of the first match, so a
 *	match on the first string carries the greatest weight (the
 *	string count) and a match on the last string carries the least
 *	(1).
 *
 *	This routine operates independently of the cursor used to
 *	enumerate a strlist.
 */
int
strlist_match(const char * const sl, size_t const slsize,
    const char * const str)
{
	unsigned int total;
	int pos = 0 /* XXXGCC 12 */;

	bool const found = strlist_match_internal(sl, slsize, str, &pos,
	    &total, match_strcmp);

	if (!found) {
		return 0;
	}
	return total - pos;
}

#if !defined(_STANDALONE)
/*
 * strlist_pmatch --
 *
 *	Weighted match like strlist_match(), except that each string in
 *	the list is tested against pattern using pmatch(9) rather than
 *	being compared for equality.  Returns 0 when nothing matches.
 */
int
strlist_pmatch(const char * const sl, size_t const slsize,
    const char * const pattern)
{
	unsigned int total;
	int pos = 0; /* XXXGCC12 */

	bool const found = strlist_match_internal(sl, slsize, pattern, &pos,
	    &total, match_pmatch);

	if (!found) {
		return 0;
	}
	return total - pos;
}
#endif /* _STANDALONE */

/*
 * strlist_index --
 *
 *	Returns the zero-based index of the first string in the strlist
 *	that equals str, or -1 if no string in the list does.
 *
 *	This routine operates independently of the cursor used to
 *	enumerate a strlist.
 */
int
strlist_index(const char * const sl, size_t const slsize,
    const char * const str)
{
	int pos;

	/* No total is requested, so the search stops at the first hit. */
	bool const found = strlist_match_internal(sl, slsize, str, &pos,
	    NULL, match_strcmp);

	return found ? pos : -1;
}

/*
 * strlist_append --
 *
 *	Append the specified string to a mutable strlist.  Returns
 *	true if successful, false upon failure for any reason.
 *
 *	*slp and *slsizep describe the current list (NULL and 0 for an
 *	empty one).  On success a new buffer is obtained from
 *	strlist_alloc(), filled with the old contents followed by str
 *	and its terminating NUL, the old buffer (if any) is released
 *	with strlist_free(), and *slp / *slsizep are updated to describe
 *	the new buffer.  On failure *slp and *slsizep are left unchanged.
 *
 *	Failure cases: slp, slsizep or str is NULL; the new size would
 *	not fit in a size_t; the allocation fails.
 */
bool
strlist_append(char ** const slp, size_t * const slsizep,
    const char * const str)
{
	/* Reject missing arguments instead of dereferencing them. */
	if (slp == NULL || slsizep == NULL || str == NULL) {
		return false;
	}

	size_t const slsize = *slsizep;
	char * const sl = *slp;

	size_t const addsize = strlen(str) + 1;
	size_t const newsize = slsize + addsize;

	/*
	 * size_t arithmetic wraps silently; a wrapped sum would make the
	 * allocation too small for the copies below.
	 */
	if (newsize < slsize) {
		return false;
	}

	char * const newbuf = strlist_alloc(newsize);

	if (newbuf == NULL) {
		return false;
	}

	if (sl != NULL) {
		memcpy(newbuf, sl, slsize);
	}

	memcpy(newbuf + slsize, str, addsize);

	/*
	 * Release the old buffer only after str has been copied, in case
	 * str points into the old list.
	 */
	if (sl != NULL) {
		strlist_free(sl, slsize);
	}

	*slp = newbuf;
	*slsizep = newsize;

	return true;
}

#ifdef STRLIST_TEST
/*
 * To build and run the tests:
 *
 * % cc -DSTRLIST_TEST -Os pmatch.c strlist.c
 * % ./a.out
 * Testing basic properties.
 * Testing enumeration.
 * Testing weighted matching.
 * Testing pattern matching.
 * Testing index return.
 * Testing string-at-index.
 * Testing gross blob count.
 * Testing gross blob indexing.
 * Testing creating a strlist.
 * Verifying new strlist.
 * All tests completed successfully.
 * %
 */

/*
 * Six non-empty strings.  sizeof(nice_blob) includes the NUL the compiler
 * appends to the literal, so "five" is terminated like the others.
 */
static char nice_blob[] = "zero\0one\0two\0three\0four\0five";
/*
 * A list that contains empty strings.  Counting the NUL the compiler
 * appends to the literal, the tests below expect seven entries:
 * "zero", "", "two", "", "four", "", "".
 */
static char gross_blob[] = "zero\0\0two\0\0four\0\0";

#include <assert.h>
#include <stdio.h>

/*
 * Self-test driver, built only when STRLIST_TEST is defined.
 *
 * Exercises every public strlist routine against nice_blob and
 * gross_blob, then builds a fresh list with strlist_append() and checks
 * that it is byte-for-byte identical to nice_blob.  Any failure trips an
 * assert(); on success the progress messages are printed and 0 is
 * returned.  argc and argv are not used.
 */
int
main(int argc, char *argv[])
{
	const char *sl;
	size_t slsize;
	size_t cursor;
	const char *cp;
	size_t size;	/* declared but not used below */

	sl = nice_blob;
	slsize = sizeof(nice_blob);

	printf("Testing basic properties.\n");
	assert(strlist_count(sl, slsize) == 6);

	/* Walk the list from the start; the seventh call must return NULL. */
	printf("Testing enumeration.\n");
	cursor = 0;
	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "zero") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "one") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "two") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "three") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "four") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) != NULL);
	assert(strcmp(cp, "five") == 0);

	assert((cp = strlist_next(sl, slsize, &cursor)) == NULL);

	/* Weight is (count - index): 6 for the first string, 1 for the last. */
	printf("Testing weighted matching.\n");
	assert(strlist_match(sl, slsize, "non-existent") == 0);
	assert(strlist_match(sl, slsize, "zero") == 6);
	assert(strlist_match(sl, slsize, "one") == 5);
	assert(strlist_match(sl, slsize, "two") == 4);
	assert(strlist_match(sl, slsize, "three") == 3);
	assert(strlist_match(sl, slsize, "four") == 2);
	assert(strlist_match(sl, slsize, "five") == 1);

	/* Expected weights correspond to "two" (4) and "four" (2). */
	printf("Testing pattern matching.\n");
	assert(strlist_pmatch(sl, slsize, "t?o") == 4);
	assert(strlist_pmatch(sl, slsize, "f[a-o][o-u][a-z]") == 2);

	printf("Testing index return.\n");
	assert(strlist_index(sl, slsize, "non-existent") == -1);
	assert(strlist_index(sl, slsize, "zero") == 0);
	assert(strlist_index(sl, slsize, "one") == 1);
	assert(strlist_index(sl, slsize, "two") == 2);
	assert(strlist_index(sl, slsize, "three") == 3);
	assert(strlist_index(sl, slsize, "four") == 4);
	assert(strlist_index(sl, slsize, "five") == 5);

	printf("Testing string-at-index.\n");
	assert(strcmp(strlist_string(sl, slsize, 0), "zero") == 0);
	assert(strcmp(strlist_string(sl, slsize, 1), "one") == 0);
	assert(strcmp(strlist_string(sl, slsize, 2), "two") == 0);
	assert(strcmp(strlist_string(sl, slsize, 3), "three") == 0);
	assert(strcmp(strlist_string(sl, slsize, 4), "four") == 0);
	assert(strcmp(strlist_string(sl, slsize, 5), "five") == 0);
	assert(strlist_string(sl, slsize, 6) == NULL);

	/* Switch to the list that contains empty strings. */
	sl = gross_blob;
	slsize = sizeof(gross_blob);

	printf("Testing gross blob count.\n");
	assert(strlist_count(sl, slsize) == 7);

	printf("Testing gross blob indexing.\n");
	assert(strcmp(strlist_string(sl, slsize, 0), "zero") == 0);
	assert(strcmp(strlist_string(sl, slsize, 1), "") == 0);
	assert(strcmp(strlist_string(sl, slsize, 2), "two") == 0);
	assert(strcmp(strlist_string(sl, slsize, 3), "") == 0);
	assert(strcmp(strlist_string(sl, slsize, 4), "four") == 0);
	assert(strcmp(strlist_string(sl, slsize, 5), "") == 0);
	assert(strcmp(strlist_string(sl, slsize, 6), "") == 0);
	assert(strlist_string(sl, slsize, 7) == NULL);


	/* Start from an empty list (NULL, 0) and grow it one string at a time. */
	printf("Testing creating a strlist.\n");
	char *newsl = NULL;
	size_t newslsize = 0;
	assert(strlist_append(&newsl, &newslsize, "zero"));
	assert(strlist_append(&newsl, &newslsize, "one"));
	assert(strlist_append(&newsl, &newslsize, "two"));
	assert(strlist_append(&newsl, &newslsize, "three"));
	assert(strlist_append(&newsl, &newslsize, "four"));
	assert(strlist_append(&newsl, &newslsize, "five"));

	printf("Verifying new strlist.\n");
	assert(strlist_count(newsl, newslsize) == 6);
	assert(strcmp(strlist_string(newsl, newslsize, 0), "zero") == 0);
	assert(strcmp(strlist_string(newsl, newslsize, 1), "one") == 0);
	assert(strcmp(strlist_string(newsl, newslsize, 2), "two") == 0);
	assert(strcmp(strlist_string(newsl, newslsize, 3), "three") == 0);
	assert(strcmp(strlist_string(newsl, newslsize, 4), "four") == 0);
	assert(strcmp(strlist_string(newsl, newslsize, 5), "five") == 0);
	assert(strlist_string(newsl, newslsize, 6) == NULL);

	/* This should be equivalent to nice_blob. */
	assert(newslsize == sizeof(nice_blob));
	assert(memcmp(newsl, nice_blob, newslsize) == 0);


	/* newsl is not released here; the process exits right after. */
	printf("All tests completed successfully.\n");
	return 0;
}

#endif /* STRLIST_TEST */