/* vim:tw=78:ts=8:sw=4:set ft=c:  */
/*
    Copyright (C) 2006-2010 Ben Kibbey <bjk@luxsci.net>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02110-1301  USA
*/
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <assert.h>
#include <err.h>

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#ifdef DEBUG
#include <execinfo.h>

/*
 * Dump the current call stack (at most 20 frames) to stderr, tagging each
 * line with FN.  Only active in DEBUG builds; a no-op statement otherwise.
 *
 * backtrace_symbols() returns a single malloc()ed array that the caller
 * owns and must free(); it can also return NULL, in which case only the
 * trailing blank line is printed.
 */
#define BACKTRACE(fn)	do { \
			    int n, nptrs; \
			    char **strings; \
			    void *buffer[20]; \
			    nptrs = backtrace(buffer, (int)(sizeof buffer / sizeof buffer[0])); \
			    strings = backtrace_symbols(buffer, nptrs); \
			    if (strings) { \
				for (n = 0; n < nptrs; n++) \
				    fprintf(stderr, "BACKTRACE (%s) %i: %s\n", fn, n, strings[n]); \
				free(strings); \
			    } \
			    fprintf(stderr, "\n"); \
			    fflush(stderr); \
	                } while (0)
#else
#define BACKTRACE(fn)	do { } while (0)
#endif

#include "mem.h"

#ifdef USE_PTH_THREADS
#include <pth.h>
#else
#include <pthread.h>
#endif

#include "gettext.h"
/* N_() here calls gettext() directly, i.e. it translates the message at
 * the point of use (it is not a no-op marker). */
#define N_(msgid)	gettext(msgid)

/* Head of the singly-linked list of live allocations handed out by the
 * x*() wrappers in this file; new entries are pushed at the head. */
static struct memlist_s *memlist;
/* Guards memlist and the DEBUG counters below.  It is initialized by
 * xmem_init(), which presumably must run before any other function here
 * is used -- confirm against callers. */
#ifdef USE_PTH_THREADS
static pth_mutex_t mem_mutex;
#else
static pthread_mutex_t mem_mutex;
#endif
#ifdef DEBUG
/* Number of entries added to / removed from memlist; reported by xdump(). */
static size_t allocations, deallocations;
#endif

/*
 * One-time setup of the allocator's mutex.  Repeated calls are ignored
 * once the first call has completed.
 */
void xmem_init()
{
    static int initialized;

    if (initialized)
	return;

#ifdef USE_PTH_THREADS
    pth_mutex_init(&mem_mutex);
#else
    pthread_mutex_init(&mem_mutex, NULL);
#endif
    initialized = 1;
}

/*
 * Look up the tracked allocation whose user pointer is PTR, unlink it from
 * memlist, wipe its contents and free both the user buffer and its
 * tracking node.
 *
 * FUNC is the name of the public caller; it is only used for DEBUG tracing.
 *
 * Returns 1 if PTR was found and released, 0 if PTR is not tracked (in
 * which case nothing is touched).
 */
static int memlist_remove(void *ptr, const char *func)
{
    /* Calling memset() through a volatile function pointer stops the
     * compiler from discarding the wipe as a dead store before free(). */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    struct memlist_s *m, *last = NULL;
    int found = 0;

    (void)func;		/* unused unless DEBUG */

#ifdef USE_PTH_THREADS
    pth_mutex_acquire(&mem_mutex, FALSE, NULL);
#else
    pthread_mutex_lock(&mem_mutex);
#endif

    for (m = memlist; m; last = m, m = m->next) {
	if (m->ptr != ptr)
	    continue;

	/* Unlink first so the node is never reachable once freed. */
	if (last)
	    last->next = m->next;
	else
	    memlist = m->next;

#ifdef DEBUG
	fprintf(stderr, "%s: %p %zu\n", func, ptr, (size_t)m->size);
	deallocations++;
#endif
	wipe(m->ptr, 0, m->size);
	free(m->ptr);
	free(m);
	found = 1;
	break;
    }

#ifdef USE_PTH_THREADS
    pth_mutex_release(&mem_mutex);
#else
    pthread_mutex_unlock(&mem_mutex);
#endif
    return found;
}

/*
 * Publish NODE by pushing it onto the head of memlist while holding
 * mem_mutex.  In DEBUG builds the allocation counter is bumped as well.
 */
static void memlist_prepend(struct memlist_s *node)
{
#ifdef USE_PTH_THREADS
    pth_mutex_acquire(&mem_mutex, FALSE, NULL);
#else
    pthread_mutex_lock(&mem_mutex);
#endif
    node->next = memlist;
    memlist = node;
#ifdef DEBUG
    allocations++;
#endif
#ifdef USE_PTH_THREADS
    pth_mutex_release(&mem_mutex);
#else
    pthread_mutex_unlock(&mem_mutex);
#endif
}

/*
 * Release a buffer obtained from this allocator.  NULL is accepted and
 * ignored.  Freeing a pointer that is not tracked prints a warning and
 * trips an assertion (which is compiled out under NDEBUG).
 */
void xfree(void *ptr)
{
    if (ptr && !memlist_remove(ptr, __FUNCTION__)) {
	warnx(N_("%s: %p not found"), __FUNCTION__, ptr);
	assert(0);
    }
}

/*
 * Allocate SIZE bytes and record the buffer in memlist so that it can be
 * released (and wiped) later by xfree()/xrealloc().
 *
 * Returns the new buffer, or NULL if either the buffer or its tracking
 * node could not be allocated; nothing is leaked on failure.  Note that
 * malloc(0) may return NULL, which is reported as a failure here.
 */
void *xmalloc(size_t size)
{
    struct memlist_s *m;
    void *p;

    if ((m = malloc(sizeof *m)) == NULL)
	return NULL;

    if ((p = malloc(size)) == NULL) {
	free(m);
	return NULL;
    }

    m->ptr = p;
    m->size = size;
    memlist_prepend(m);
#ifdef DEBUG
    /* size is a size_t: %zu, not %i (mismatched, undefined on LP64). */
    fprintf(stderr, "%s: %p %zu\n", __FUNCTION__, p, size);
    BACKTRACE(__FUNCTION__);
#endif
    /* Return the local copy rather than re-reading the now-shared node. */
    return p;
}

/*
 * Allocate a zero-filled array of NMEMB elements of SIZE bytes each and
 * record it in memlist, like xmalloc().
 *
 * Returns the new buffer, or NULL if the buffer or its tracking node could
 * not be allocated; nothing is leaked on failure.
 */
void *xcalloc(size_t nmemb, size_t size)
{
    struct memlist_s *m;
    void *p;

    if ((m = malloc(sizeof *m)) == NULL)
	return NULL;

    /* calloc() itself rejects an overflowing nmemb * size, so the product
     * stored below is only computed after a successful allocation. */
    if ((p = calloc(nmemb, size)) == NULL) {
	free(m);
	return NULL;
    }

    m->ptr = p;
    m->size = nmemb * size;
    memlist_prepend(m);
#ifdef DEBUG
    fprintf(stderr, "%s: %p %zu\n", __FUNCTION__, p, nmemb * size);
    BACKTRACE(__FUNCTION__);
#endif
    return p;
}

/*
 * Resize a tracked buffer.
 *
 *   - ptr != NULL, size == 0: the buffer is released via xfree(); returns NULL.
 *   - ptr == NULL: behaves like xmalloc(size).
 *   - otherwise: a new buffer of SIZE bytes receives min(old, new) bytes of
 *     the old contents, the old buffer is wiped and freed, and the new
 *     buffer is returned.  On allocation failure NULL is returned and the
 *     original buffer is left untouched and still tracked.
 *
 * Resizing an untracked pointer prints a warning, trips an assertion
 * (compiled out under NDEBUG) and returns NULL.
 */
void *xrealloc(void *ptr, size_t size)
{
    /* Volatile function pointer so the wipe of the old buffer cannot be
     * optimized away as a dead store before free(). */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    struct memlist_s *m;
    void *p;

    if (!size && ptr) {
	xfree(ptr);
	return NULL;
    }

    if (!ptr)
	return xmalloc(size);

#ifdef USE_PTH_THREADS
    pth_mutex_acquire(&mem_mutex, FALSE, NULL);
#else
    pthread_mutex_lock(&mem_mutex);
#endif

    for (m = memlist; m; m = m->next) {
	if (m->ptr != ptr)
	    continue;

	/* Copy into a fresh buffer instead of calling realloc() so the old
	 * buffer can be wiped before it is handed back to the heap. */
	if ((p = malloc(size)) == NULL) {
#ifdef USE_PTH_THREADS
	    pth_mutex_release(&mem_mutex);
#else
	    pthread_mutex_unlock(&mem_mutex);
#endif
	    return NULL;
	}

	memcpy(p, m->ptr, size < m->size ? size : m->size);
	wipe(m->ptr, 0, m->size);
	free(m->ptr);
	m->ptr = p;
	m->size = size;
#ifdef DEBUG
	fprintf(stderr, "%s: %p %zu\n", __FUNCTION__, p, size);
	BACKTRACE(__FUNCTION__);
#endif
#ifdef USE_PTH_THREADS
	pth_mutex_release(&mem_mutex);
#else
	pthread_mutex_unlock(&mem_mutex);
#endif
	/* Return the local copy: m must not be read after the unlock. */
	return p;
    }

    warnx(N_("%s: %p not found"), __FUNCTION__, ptr);
#ifdef USE_PTH_THREADS
    pth_mutex_release(&mem_mutex);
#else
    pthread_mutex_unlock(&mem_mutex);
#endif
    assert(0);
    return NULL;
}

/*
 * Duplicate the NUL-terminated string STR into a tracked buffer obtained
 * from xmalloc().  STR must not be NULL.
 *
 * Returns the copy (to be released with xfree()), or NULL if the
 * allocation failed.
 */
char *xstrdup(const char *str)
{
    size_t len = strlen(str) + 1;	/* include the terminating NUL */
    char *t = xmalloc(len);

    if (!t)
	return NULL;

    memcpy(t, str, len);
#ifdef DEBUG
    fprintf(stderr, "%s: %p\n", __FUNCTION__, (void *)t);
    BACKTRACE(__FUNCTION__);
#endif
    return t;
}

/*
 * Wipe and release every tracked allocation, leaving memlist empty.
 *
 * The entries are released directly while mem_mutex is held.  The previous
 * version called xfree() here, which re-locks mem_mutex inside
 * memlist_remove().  With the default (non-recursive) pthread mutex that
 * is undefined behaviour and in practice a self-deadlock.
 */
void xpanic(void)
{
    /* Volatile function pointer so the wipe cannot be elided before free(). */
    static void *(*const volatile wipe)(void *, int, size_t) = memset;
    struct memlist_s *m, *next;

#ifdef USE_PTH_THREADS
    pth_mutex_acquire(&mem_mutex, FALSE, NULL);
#else
    pthread_mutex_lock(&mem_mutex);
#endif

    for (m = memlist; m; m = next) {
	next = m->next;
#ifdef DEBUG
	fprintf(stderr, "%s: %p %zu\n", __FUNCTION__, m->ptr, (size_t)m->size);
	deallocations++;
#endif
	wipe(m->ptr, 0, m->size);
	free(m->ptr);
	free(m);
    }

    memlist = NULL;

#ifdef USE_PTH_THREADS
    pth_mutex_release(&mem_mutex);
#else
    pthread_mutex_unlock(&mem_mutex);
#endif
}

#ifdef DEBUG
/*
 * DEBUG only: print every still-tracked allocation to stderr, followed by
 * the total number of unfreed bytes and the allocation/deallocation
 * counters.  Runs under mem_mutex so the list is stable while walked.
 */
void xdump(void)
{
    struct memlist_s *m;
    size_t total = 0;

#ifdef USE_PTH_THREADS
    pth_mutex_acquire(&mem_mutex, FALSE, NULL);
#else
    pthread_mutex_lock(&mem_mutex);
#endif

    for (m = memlist; m; m = m->next) {
	fprintf(stderr, "%s: %p %zu\n", __FUNCTION__, m->ptr, (size_t)m->size);
	total += m->size;
    }

    /* All three values are size_t, hence %zu (%i was undefined on LP64). */
    fprintf(stderr, "Total unfreed: %zu bytes, allocations: %zu, deallocations: %zu\n", total,
	    allocations, deallocations);
#ifdef USE_PTH_THREADS
    pth_mutex_release(&mem_mutex);
#else
    pthread_mutex_unlock(&mem_mutex);
#endif
}
#endif
