#include "ft_openft.h"
#include "ft_bloom.h"

static inline void bit_set (BloomFilter *bf, int bit)
{
	int offset = bit & bf->mask;
	
	if (bf->count)
	{
		uint8_t *ptr = &bf->count[offset];

		if (*ptr < 255)
			(*ptr)++;
	}

	bf->table[offset >> 3] |= 1 << (offset & 7);
}

static inline void bit_unset (BloomFilter *bf, int bit)
{
	int offset = bit & bf->mask;
	
	if (bf->count)
	{
		uint8_t *ptr = &bf->count[offset];

		if (!*ptr || --(*ptr))
			return;
	}

	bf->table[offset >> 3] &= ~(1 << (offset & 7));
}

static inline BOOL bit_get (BloomFilter *bf, int bit)
{
	int offset = bit & bf->mask;

	return BOOL_EXPR(bf->table[offset >> 3] & (1 << (offset & 7)));
}

static unsigned int bit_count (BloomFilter *bf)
{
	int i;

	uint32_t *ptr = (uint32_t *)bf->table;

	unsigned int count = 0;
	
	int max = 1 << (bf->bits - 5);

	for (i=0; i<max; i++)
	{
		unsigned int b = *(ptr++);

		b = (b & 0x55555555) + ((b & 0xaaaaaaaa) >> 1);
		b = (b & 0x33333333) + ((b & 0xcccccccc) >> 2);
		b = (b & 0x0f0f0f0f) + ((b & 0xf0f0f0f0) >> 4);
		b = (b & 0x00ff00ff) + ((b & 0xff00ff00) >> 8);
		b = (b & 0x0000ffff) + (b >> 16);

		count += b;
	}
	
	return count;
}

/*******************************************************************/

BloomFilter *ft_bloom_new (int bits, int nhash, int keylen, BOOL count)
{
	BloomFilter *bf;

	/* make sure everything's long enough */
	if (bits < 5 || nhash * ((bits+7) & ~7) > keylen)
		return NULL;
	
	bf = MALLOC (sizeof (BloomFilter));
	
	if (!bf)
		return NULL;
	
	bf->table = CALLOC (1, 1 << (bits - 3));

	if (!bf->table)
	{
		free (bf);
		return NULL;
	}

	if (count)
	{
		bf->count = CALLOC (1, 1 << bits);
		
		if (!bf->count)
		{
			free (bf->table);
			free (bf);
			
			return NULL;
		}
	}
	else
		bf->count = NULL;

	bf->bits   = bits;
	bf->mask   = (1 << bits) - 1;
	bf->nhash  = nhash;
	bf->keylen = keylen;

	return bf;
}

void ft_bloom_free (BloomFilter *bf)
{
	if (bf)
		free (bf->table);
	
	free (bf);
}

BloomFilter *ft_bloom_clone (BloomFilter *bf)
{
	BloomFilter *clone;

	clone = ft_bloom_new (bf->bits, bf->nhash, bf->keylen,
			      FALSE);

	if (!clone)
		return NULL;

	memcpy (clone->table, bf->table, 1 << (bf->bits - 3));
	
	return clone;
}

#define GET_HASH(loc)                                            \
	int loc = 0;                                             \
	int left = (bf->bits+7) / 8;                             \
	int shift = 0;                                           \
	                                                         \
	while (left--)                                           \
	{                                                        \
		loc += ((uint8_t *)key)[offset++] << shift;      \
		shift += 8;                                      \
	}

void ft_bloom_add (BloomFilter *bf, void *key)
{
	int n;
	int offset = 0;
	
	for (n=0; n<bf->nhash; n++)
	{
		GET_HASH (loc);
		
		bit_set (bf, loc);
	}
}

void ft_bloom_add_int (BloomFilter *bf, int key)
{
	int n;
	int shift = (bf->bits+7) & ~7;

	for (n=0; n<bf->nhash; n++)
	{
		bit_set (bf, key);
		
		key >>= shift;
	}
}

BOOL ft_bloom_lookup (BloomFilter *bf, void *key)
{
	int n;
	int offset = 0;
	
	for (n=0; n<bf->nhash; n++)
	{
		GET_HASH (loc);
		
		if (!bit_get (bf, loc))
			return FALSE;
	}

	return TRUE;
}

BOOL ft_bloom_lookup_int (BloomFilter *bf, int key)
{
	int n;
	int shift = (bf->bits+7) & ~7;

	for (n=0; n<bf->nhash; n++)
	{
		if (!bit_get (bf, key))
			return FALSE;
		
		key >>= shift;
	}

	return TRUE;
}

BOOL ft_bloom_remove (BloomFilter *bf, void *key)
{
	int n;
	int offset = 0;

	if (!bf->count)
		return FALSE;

	for (n=0; n<bf->nhash; n++)
	{
		GET_HASH (loc);

		bit_unset (bf, loc);
	}

	return TRUE;
}

BOOL ft_bloom_remove_int (BloomFilter *bf, int key)
{
	int n;
	int shift = (bf->bits+7) & ~7;

	if (!bf->count)
		return FALSE;

	for (n=0; n<bf->nhash; n++)
	{
		bit_unset (bf, key);

		key >>= shift;
	}

	return TRUE;
}

BOOL ft_bloom_diff (BloomFilter *new, BloomFilter *old)
{
	int i, max;

	uint32_t *oldptr, *newptr;

	if (new->bits != old->bits)
		return FALSE;

	max = 1 << (new->bits - 5);
	
	oldptr = (uint32_t *)old->table;
	newptr = (uint32_t *)new->table;

	for (i=0; i<max; i++)
		*(oldptr++) ^= *(newptr++);

	return TRUE;
}

void ft_bloom_clear (BloomFilter *bf)
{
	memset (bf->table, 0, 1 << (bf->bits - 3));
	
	if (bf->count)
		memset (bf->count, 0, 1 << bf->bits);
}

/*******************************************************************/

#if 0
int main (void)
{
	BloomFilter *bf = ft_bloom_new (14, 2, 32, TRUE);
	int i;
	int exp_count = 0;
	
	assert (bf);
	
	for (i = 0; i < 100000; i++)
	{
		int r = rand() & 65535;
		int count;

		if (!bit_get (bf, r))
			exp_count++;

		bit_set (bf, r);
		
		count = bit_count (bf);
		assert (count == exp_count);
	}

	return 0;
}
#endif

/*******************************************************************/

#if 1
#define BITS 20
#define KEYS 10000
#include <math.h>
#include <time.h>

int main (void)
{
	BloomFilter *bf = ft_bloom_new (BITS, 1, 32, TRUE);
	int i;
	int keys[KEYS];
	int fpos = 0;
	int count;
	double ratio;

	assert (bf);

	srand (time (NULL));

	for (i = 0; i < KEYS; i++)
	{
		int r = rand();

		keys[i] = r;
		ft_bloom_add_int (bf, r);
	}

	count = bit_count (bf);
	ratio = (double)count/(1<<BITS);
	printf ("size=%d, used=%d, ratio=%f, compressed size~=%.0f\n",
		1<<BITS, count, ratio,
		-(1<<BITS)*
		(((log(ratio)*ratio+log(1-ratio)*(1-ratio))/8/log(2))));

	/* Check the equivalence of ft_bloom_lookup and
	 * ft_bloom_lookup_int. These need to be equivalent (despite
	 * it being marginally less efficient) because otherwise the
	 * filter would have to include a field for how the key had
	 * been specified, which would be evil.
	 */
	for (i=0; i < KEYS; i++)
	{
		uint8_t h[4];
		h[0] = keys[i];
		h[1] = keys[i] >> 8;
		h[2] = keys[i] >> 16;
		h[3] = keys[i] >> 24;

		assert (ft_bloom_lookup_int (bf, keys[i]));
		assert (ft_bloom_lookup (bf, h));
	}

	for (i=0; i < KEYS; i++)
	{
		int j,r;

		do {
			r = rand();

			for (j=0; j< KEYS; j++)
				if (keys[j] == r)
					break;
		} while (j < KEYS);

		fpos += ft_bloom_lookup_int (bf, r);
	}

#if 0
	write (2, bf->table, 1<< (BITS-3));
#endif
	printf ("false positive rate=%f\n", ((double)fpos)/KEYS);

	/* check that removal works */
	for (i = 0; i < KEYS; i++)
	{
		int j;
		assert (ft_bloom_remove_int (bf, keys[i]));
		
		for (j=0; j < KEYS; j++)
			assert (ft_bloom_lookup_int (bf, keys[j]) >=
				(j > i));
	}

	/* should now be empty */
	assert (bit_count (bf) == 0);

	return 0;
}
#endif
