/** 
 * @file memory_pool.c
 * Contains functions that provide the allocators with memory
 *
 * The memory pool returns properly protected memory for the allocators 
 * to align and carve up as they see fit. Memory is also returned to the pool 
 * to be cached, freed, etc.
 *
 * Copyright (C) 2001 by Mike Perry.
 * Distributed WITHOUT WARRANTY under the GPL. See COPYING for details.
 */
#include <lib/memory_pool.h>
#include <lib/portability.h>
#include <lib/output.h>
#include <sys/mman.h>
#include <stdlib.h>
#include <fcntl.h>

/* Forward declaration: __nj_memory_pool_user_init calls this before its
 * definition further down in the file. */
static void memory_pool_trim(struct nj_memory_pool *memory_pool);

/**
 * Add a new table of memory to the pool, and return a chunk if requested.
 *
 * A new sub-table is reserved at the top of the table of tables and then
 * bootstrapped with the largest size that succeeds: NJ_MEM_TABLE_SIZE first,
 * halving on each failed attempt, down to a floor of 2*NJ_PAGE_SIZE.
 *
 * @param memory_pool The memory pool we're working with
 * @param size The size of the chunk to return from the new table (the
 * bootstrap code passes 0).
 *
 * @returns The chunk obtained from the new table, or 0 if no table size could
 * be initialized (the request path treats 0 as out of memory).
 *
 * @NOTES: The reason why this is lumped together is because we cut a 
 * corner by not making the subtables threadsafe. We have to ensure that all
 * accesses to subtables are atomic, and sometimes we need to add memory and 
 * get it in one operation.
 */
static nj_addr_t memory_pool_add_memory_and_get(struct nj_memory_pool *memory_pool, 
		size_t size)
{
	struct nj_table *new_table;
	nj_addr_t ret = 0;
	int initialized = 0;
	int i;

	/* Request a table from the table of tables. This is where our mutex 
	 * protection comes from */	
	new_table = (struct nj_table *)__nj_table_request_top(&memory_pool->memory_tables, 
			sizeof(struct nj_table_light));

	/* Needn't be atomic, because we'll protect it in here */
	/* We also only need to bootstrap this table, since it has no mutexes.
	 * Retry with progressively smaller sizes if the larger ones fail. */
	for (i=0; (NJ_MEM_TABLE_SIZE >> i) >= 2*NJ_PAGE_SIZE; i++)
	{
		if(__nj_table_bootstrap_init(new_table, NULL, NJ_MEM_TABLE_SIZE >> i, 0, 0) == 0)
		{
			initialized = 1;
			break;
		}
	}

	/* Never carve a chunk out of a table whose init failed: its contents
	 * would be uninitialized. 
	 * NOTE(review): the reserved (uninitialized) entry still stays in the
	 * table of tables; callers are expected to abort on a 0 return. */
	if(initialized)
		ret = (nj_addr_t)__nj_table_get_chunk(new_table, size);

	__nj_table_release_top(&memory_pool->memory_tables, 
			sizeof(struct nj_table_light));

	return ret;
}

/**
 * Create a new memory pool
 *
 * Initializes the table of tables, the block-cache stacks and the first
 * memory table. It then creates an (immediately unlinked) fencepost file
 * holding one page of NJ_FENCEPOST words. __nj_memory_pool_release_block
 * maps that file over freed blocks when checking for double frees. The
 * descriptor is stored in memory_pool->fencepost_fd.
 *  
 * @param memory_pool The pool to init
 * @returns The beginning of the first table for the allocator to keep track of 
 * which allocs came from system, which from user. Do you believe in the user?
 */
nj_addr_t __nj_memory_pool_bootstrap_init(struct nj_memory_pool *memory_pool)
{
	u_int fencepost_page[4096/sizeof(u_int)];
	char tmpf[] = "./njamd-zeroXXXXXX";
	nj_addr_t table_start;
	int fd;
	int i,j;

	/* This table of tables is atomic, and we use it as a makeshift lock for 
	 * us */
	if(__nj_table_bootstrap_init(&memory_pool->memory_tables, NULL, NJ_MP_TABLE_SIZE, 1, 1))
		__nj_critical_error(__FUNCTION__": table_init");

	/* yes, this loop order is correct: first_table has one entry per block
	 * size, and each cache has one stack per block size */
	for(j = 0; j < NJ_MP_CACHE_BLOCKS; j++)
	{
		for(i=0; i < NJ_MP_NUM_CACHES; i++)
			__nj_stack_bootstrap_init(&memory_pool->block_cache[i][j]);
		memory_pool->first_table[j] = 0;
	}

	/* A 0-sized request sets up the first table; the chunk returned marks
	 * where it starts */
	table_start = memory_pool_add_memory_and_get(memory_pool, 0);
	
	/* Get a file descriptor for the fencepost page mapping. Both failure
	 * paths are fatal: NJ_CHK_FREE_ERROR frees need this descriptor. */
#ifdef HAVE_WORKING_ERRNO
	if((fd = mkstemp(tmpf)) == -1)
	{
		__nj_perror(__FUNCTION__": mkstemp");
		_exit(1);
	}
#else
	if((fd = open(tmpf, O_TRUNC|O_CREAT|O_RDWR, 0600)) == -1)
	{
		__nj_perror(__FUNCTION__ ": Can't create fencepost file");
		_exit(1);
	}
#endif

	for(i=0; i < sizeof(fencepost_page)/sizeof(fencepost_page[0]); i++)
		fencepost_page[i] = NJ_FENCEPOST;
	
	/* Consider this an easter egg ;) */
	/* Fill one page of the file with the fencepost pattern. Touching a page
	 * of a MAP_SHARED mapping that lies wholly past end-of-file raises
	 * SIGBUS. A failed or short write is therefore treated as fatal here,
	 * instead of surfacing later as a mysterious crash. */
	for(i = 0; i < NJ_PAGE_SIZE; i+=sizeof(fencepost_page))
	{
		if(write(fd, fencepost_page, sizeof(fencepost_page)) != (ssize_t)sizeof(fencepost_page))
		{
			__nj_perror(__FUNCTION__ ": Can't write fencepost file");
			_exit(1);
		}
	}

	lseek(fd, 0, SEEK_SET);
	
	/* Only the descriptor is needed from here on, not the name */
	unlink(tmpf);

	memory_pool->fencepost_fd = fd;

	return table_start;
}

/**
 * User-time initialization of the memory pool.
 *
 * Steps, in order:
 * - run user init on the table of tables and on every block-cache stack;
 * - resolve libc's malloc and free;
 * - when allocations are unprotected (prot=none) and the allocation type is
 *   not mutable, trim the unused memory from our tables.
 *
 * @param memory_pool The pool being initialized
 * @param syms Symbol table used to resolve libc's malloc and free
 * @param prefs User preferences (allocation type, mutability)
 *
 * @returns A pointer to the next chunk that will be returned from the first
 * table. Used to determine which allocs came from system, which from user.
 * Do you believe in the user?
 */
nj_addr_t __nj_memory_pool_user_init(struct nj_memory_pool *memory_pool,
		struct nj_libc_syms *syms,
		struct nj_prefs *prefs)
{
	/* Address of the first sub-table (index 0 of the table of tables) */
	struct nj_table *first = (struct nj_table *)NJ_TABLE_INDEX_TO_PTR(memory_pool->memory_tables, 
			0, struct nj_table_light);
	nj_addr_t boundary;
	int blk, cache;

	__nj_table_user_init(&memory_pool->memory_tables, prefs);

	for(blk = 0; blk < NJ_MP_CACHE_BLOCKS; blk++)
		for(cache = 0; cache < NJ_MP_NUM_CACHES; cache++)
			__nj_stack_user_init(&memory_pool->block_cache[cache][blk]);

	memory_pool->libc_malloc = __nj_libc_syms_resolve_libc(syms, "malloc");
	memory_pool->libc_free = __nj_libc_syms_resolve_libc(syms, "free");

	/* Everything handed out of the first table before this point was a
	 * system allocation */
	boundary = (nj_addr_t)__nj_table_get_chunk(first, 0);

	/* With prot=none, requests go to libc's malloc, so when the type can't
	 * change later our tables' spare memory can be released */
	if(!prefs->stat.mutable_alloc && prefs->dyn.alloc_type == NJ_PROT_NONE)
		memory_pool_trim(memory_pool);

	return boundary;
}

/**
 * Finish off usage of a memory pool
 *
 * Finalizes the table of tables first, then every block-cache stack.
 * NOTE(review): memory_pool->fencepost_fd (opened at bootstrap) is not
 * closed here.
 *
 * @param memory_pool The pool to drain
 */
void __nj_memory_pool_fini(struct nj_memory_pool *memory_pool)
{
	int cache;

	__nj_table_fini(&memory_pool->memory_tables);

	for(cache = 0; cache < NJ_MP_NUM_CACHES; cache++)
	{
		int blk = 0;

		while(blk < NJ_MP_CACHE_BLOCKS)
		{
			__nj_stack_fini(&memory_pool->block_cache[cache][blk]);
			blk++;
		}
	}
}

/**
 * Iterator callback that truncates a single sub-table (used for prot=none).
 *
 * @param table The table to trim
 * @param ap Unused; this iteration passes no extra arguments
 * 
 * @returns NULL always, so that the iteration continues for all tables.
 */ 
static void *sub_table_trim(void *table, va_list ap)
{
	(void)ap;	/* no extra arguments are consumed */

	__nj_table_trunc(table);

	return NULL;
}
	
/**
 * Give back the mapped-but-unused memory of all memory tables.
 * 
 * Used when the protection scheme is prot=none: allocations are then served
 * by libc's malloc rather than by our own tables.
 *
 * @param memory_pool The memory pool to trim
 */
static void memory_pool_trim(struct nj_memory_pool *memory_pool)
{
	/* Step 1: truncate each sub-table through the iterator callback.
	 * @FIXME: Technically only the top table needs truncing */
	__nj_table_for_all_entries(&memory_pool->memory_tables, 0,
			sizeof(struct nj_table_light), sub_table_trim);

	/* Step 2: truncate the table of tables itself */
	__nj_table_trunc(&memory_pool->memory_tables);
}

/**
 * Iterator callback that tries to carve a chunk out of one memory table.
 *
 * @param table The table to check for a chunk
 * @param ap The va_list holding the requested size, passed as a size_t
 *
 * @returns Whatever __nj_table_get_chunk returns for this table (presumably
 * NULL when the table cannot satisfy the request, which keeps the caller's
 * iteration going).
 */
static void *memory_table_get_block(void *table, va_list ap)
{
	size_t wanted = va_arg(ap, size_t);

	/* protected by the memory_pool's lock */
	return __nj_table_get_chunk(table, wanted);
}

/**
 * Apply guard-page protection to a block that did not come from a free-list
 * cache. Cached blocks keep the protection they were first given, which is
 * the essence of the optimization.
 *
 * @param block Start of the block
 * @param size Size of the block in bytes
 * @param alloc_type The protection type. NJ_PROT_OVER protects the last
 * page, NJ_PROT_UNDER the first page. Any other type (e.g. strict underflow,
 * which needs unprotected memory) is left untouched.
 */
static void memory_pool_protect_fresh_block(nj_addr_t block, size_t size,
		int alloc_type)
{
	void *guard;

	if(alloc_type == NJ_PROT_OVER)
		guard = (void *)(block + size - NJ_PAGE_SIZE);
	else if(alloc_type == NJ_PROT_UNDER)
		guard = (void *)block;
	else
		return;

	/* A silently failed mprotect would leave the block without its guard */
	if(mprotect(guard, NJ_PAGE_SIZE, __nj_prot) == -1)
		__nj_critical_error("memory_pool_protect_fresh_block/mprotect");
}

/**
 * Request block of memory from the memory pool
 *
 * With prot=none the request is forwarded to libc's malloc. Otherwise the
 * block is taken from the free-list cache for its size and protection type
 * when one is available. If not, it is carved from the memory tables (adding
 * a new table if needed) and then has its guard page protected.
 * 
 * @param memory_pool The pool to get from
 * @param size The size of the block in bytes
 * @param prefs Dynamic preferences; prefs.alloc_type selects the protection
 * type and the cache the block comes from
 *
 * @returns The start of the block. Running out of memory is reported through
 * __nj_critical_error.
 */
nj_addr_t __nj_memory_pool_request_block(struct nj_memory_pool *memory_pool, 
		size_t size, struct nj_dynamic_prefs prefs)
{
	nj_addr_t ret = 0;
	int idx = -1;
	int tmp;

	if(prefs.alloc_type == NJ_PROT_NONE)
		return (nj_addr_t)memory_pool->libc_malloc(size);

	/* Cache slot is ((size - NJ_PAGE_SIZE) >> NJ_PAGE_SHIFT) - 1. size is
	 * unsigned, so blocks under two pages would wrap to a negative/huge
	 * index; those (and blocks too large to cache) get no slot (-1). */
	if(size >= 2*NJ_PAGE_SIZE &&
			((size - NJ_PAGE_SIZE) >> NJ_PAGE_SHIFT) - 1 < (size_t)NJ_MP_CACHE_BLOCKS)
		idx = (int)(((size - NJ_PAGE_SIZE) >> NJ_PAGE_SHIFT) - 1);

	if(idx >= 0)
	{
		/* Reuse a cached block if there is one; it needs no mprotect */
		if((ret = (nj_addr_t)__nj_stack_pop(&memory_pool->block_cache[prefs.alloc_type][idx])) != 0)
		{
			/* Underflow blocks are cached by their first writable page,
			 * one page past the protected one */
			if(prefs.alloc_type == NJ_PROT_UNDER)
				ret -= NJ_PAGE_SIZE;
			return ret;
		}

		if((ret = (nj_addr_t)__nj_table_for_all_entries(&memory_pool->memory_tables, 
						&memory_pool->first_table[idx],
						sizeof(struct nj_table_light), 
						memory_table_get_block, size)) == 0)
		{
			/** @TODO move to output object */
			if((ret = memory_pool_add_memory_and_get(memory_pool, size)) == 0)
				__nj_critical_error(__FUNCTION__": Out of Memory");
		}
	}
	else
	{
		/* We don't want our first table to be updated if this size is really 
		 * large, and smaller blocks exist */
		/* Not lock-worthy, cause at worst a race will just make tmp too low */
		tmp = memory_pool->first_table[NJ_MP_CACHE_BLOCKS-1];

		if((ret = (nj_addr_t)__nj_table_for_all_entries(&memory_pool->memory_tables, 
						&tmp, sizeof(struct nj_table_light), 
						memory_table_get_block, size)) == 0)
		{
			if((ret = memory_pool_add_memory_and_get(memory_pool, size)) == 0)
				__nj_critical_error(__FUNCTION__": Out of memory");
		}
	}

	/* Only memory that doesn't come from the free list needs mprotect */
	memory_pool_protect_fresh_block(ret, size, prefs.alloc_type);

	return ret;
}

/**
 * Release block to the memory pool
 *
 * This function is called to 'return' a block of memory to the pool.
 * Depending on free_type, the memory is remapped, cached in one of the free
 * lists, unmapped, or left alone. Blocks allocated with prot=none are handed
 * back to libc's free.
 * 
 * @param memory_pool The memory pool
 * @param block The start of the block to release (presumably the address
 * __nj_memory_pool_request_block returned; for NJ_PROT_UNDER its first page
 * is the protected one)
 * @param size The size of the allocation
 * @param alloc_type The type of allocator that made this block
 * @param free_type The type of free checking to enact on this block
 *
 * @NOTE don't think about using prefs here, because alloc and free type must 
 * come from two different places (alloc_type is the type when the block was 
 * in existence)
 */
void __nj_memory_pool_release_block(struct nj_memory_pool *memory_pool, 
		nj_addr_t block, size_t size, int alloc_type, int free_type)
{
	if(alloc_type == NJ_PROT_NONE)
	{
		/** @FIXME what about insane dlopen? */
		memory_pool->libc_free((void *)block);
		return;	
	}
	
	switch(free_type)
	{
		/* To check for double frees, map the fencepost page (the file filled
		 * with NJ_FENCEPOST at bootstrap) over the first page, and fresh
		 * anonymous memory over the rest of the block */
		case NJ_CHK_FREE_ERROR:
			if(mmap((void *)block, NJ_PAGE_SIZE, __nj_prot, MAP_FIXED | MAP_SHARED,
						memory_pool->fencepost_fd, 0) == MAP_FAILED)
				__nj_critical_error(__FUNCTION__"/mmap");
			
			/* A zero-length mmap fails, so skip it for single-page blocks */
			if(size > NJ_PAGE_SIZE &&
					mmap((void *)(block + NJ_PAGE_SIZE), size - NJ_PAGE_SIZE, __nj_prot, 
						MAP_FIXED | MAP_PRIVATE | MAP_ANON, __nj_anonfd, 0) == MAP_FAILED)
				__nj_critical_error(__FUNCTION__"/mmap");
			break;

		/* Replace the whole block with fresh mappings using __nj_prot
		 * (presumably so any later access faults) */
		case NJ_CHK_FREE_SEGV:
			if(mmap((void *)block, size, __nj_prot, MAP_FIXED | MAP_PRIVATE | MAP_ANON,
						__nj_anonfd, 0) == MAP_FAILED)
				__nj_critical_error(__FUNCTION__"/mmap");
			break;

		/* Cache the block when it has a cache slot (same slot formula as
		 * the request path); otherwise give the pages back to the system.
		 * Blocks under two pages would compute a negative slot, so they
		 * are unmapped too. */
		case NJ_CHK_FREE_NONE:
			if(size >= 2*NJ_PAGE_SIZE &&
					((size - NJ_PAGE_SIZE) >> NJ_PAGE_SHIFT) - 1 < (size_t)NJ_MP_CACHE_BLOCKS)
			{
				size_t idx = ((size - NJ_PAGE_SIZE) >> NJ_PAGE_SHIFT) - 1;

				/* Strict-underflow blocks: make the first page accessible
				 * again, since the stack item is written into it below */
				if(alloc_type == NJ_PROT_SUNDER &&
						mprotect((void *)block, NJ_PAGE_SIZE, PROT_READ|PROT_WRITE|PROT_EXEC) == -1)
					__nj_critical_error(__FUNCTION__"/mprotect");

				/* Underflow blocks keep their protected first page, so the
				 * stack item goes in the next (writable) page */
				__nj_stack_push(&memory_pool->block_cache[alloc_type][idx], 
						(struct nj_stack_item *)(block + (alloc_type == NJ_PROT_UNDER ? NJ_PAGE_SIZE : 0)));
			}
			/* Ouch.. forgetting to munmap here would have been bad */
			else if(munmap((void *)block, size) == -1)
				__nj_critical_error(__FUNCTION__"/munmap");
			break;

		case NJ_CHK_FREE_NOFREE:
			/* Do nothing */
			break;

		default:
			__nj_eprintf("Unknown free checking option %d\n", free_type);
			if(mmap((void *)block, size, __nj_prot, MAP_FIXED | MAP_PRIVATE | MAP_ANON,
						__nj_anonfd, 0) == MAP_FAILED)
				__nj_critical_error(__FUNCTION__"/mmap");
			break;
	}
}

// vim:ts=4
