Adding a remove_at method to the hash table.
[strongswan.git] / src / libstrongswan / utils / hashtable.c
1 /*
2 * Copyright (C) 2008-2010 Tobias Brunner
3 * Hochschule fuer Technik Rapperswil
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #include <utils/linked_list.h>
17
18 #include "hashtable.h"
19
/** Upper bound for the number of rows of a hash table (MUST be a power of 2) */
#define MAX_CAPACITY (1 << 30)

/**
 * One entry of the hash table: the key and value pointers passed to put(),
 * together with the hash that was computed for the key.
 */
typedef struct pair_t {
	/** key pointer of the entry, matched against lookups via the equality function */
	void *key;
	/** value pointer stored for the key */
	void *value;
	/** cached hash of the key, so rehash() can place the entry without rehashing the key */
	u_int hash;
} pair_t;
44
45 /**
46 * Creates an empty pair object.
47 */
48 pair_t *pair_create(void *key, void *value, u_int hash)
49 {
50 pair_t *this;
51
52 INIT(this,
53 .key = key,
54 .value = value,
55 .hash = hash,
56 );
57
58 return this;
59 }
60
61 typedef struct private_hashtable_t private_hashtable_t;
62
63 /**
64 * Private data of a hashtable_t object.
65 *
66 */
67 struct private_hashtable_t {
68 /**
69 * Public part of hash table.
70 */
71 hashtable_t public;
72
73 /**
74 * The number of items in the hash table.
75 */
76 u_int count;
77
78 /**
79 * The current capacity of the hash table (always a power of 2).
80 */
81 u_int capacity;
82
83 /**
84 * The current mask to calculate the row index (capacity - 1).
85 */
86 u_int mask;
87
88 /**
89 * The load factor.
90 */
91 float load_factor;
92
93 /**
94 * The actual table.
95 */
96 linked_list_t **table;
97
98 /**
99 * The hashing function.
100 */
101 hashtable_hash_t hash;
102
103 /**
104 * The equality function.
105 */
106 hashtable_equals_t equals;
107 };
108
109 typedef struct private_enumerator_t private_enumerator_t;
110
111 /**
112 * hash table enumerator implementation
113 */
114 struct private_enumerator_t {
115
116 /**
117 * implements enumerator interface
118 */
119 enumerator_t enumerator;
120
121 /**
122 * associated hash table
123 */
124 private_hashtable_t *table;
125
126 /**
127 * current row index
128 */
129 u_int row;
130
131 /**
132 * current pair
133 */
134 pair_t *pair;
135
136 /**
137 * enumerator for the current row
138 */
139 enumerator_t *current;
140 };
141
142 /**
143 * Compare a pair in a list with the given key.
144 */
145 static inline bool pair_equals(pair_t *pair, private_hashtable_t *this, void *key)
146 {
147 return this->equals(key, pair->key);
148 }
149
/**
 * Round n up to the nearest power of two (n itself if it already is one).
 *
 * All bits below the most significant set bit of n - 1 are set to 1, and the
 * result is then incremented, which carries into the next power of two.
 * Returns 0 for n == 0, and for n above the largest power of two that fits in
 * a u_int, because the final increment wraps around.
 */
static u_int get_nearest_powerof2(u_int n)
{
	u_int shift = 1;

	n -= 1;
	while (shift < sizeof(u_int) * 8)
	{
		n |= n >> shift;
		shift <<= 1;
	}
	return n + 1;
}
167
168 /**
169 * Init hash table parameters
170 */
171 static void init_hashtable(private_hashtable_t *this, u_int capacity)
172 {
173 capacity = max(1, min(capacity, MAX_CAPACITY));
174 this->capacity = get_nearest_powerof2(capacity);
175 this->mask = this->capacity - 1;
176 this->load_factor = 0.75;
177
178 this->table = calloc(this->capacity, sizeof(linked_list_t*));
179 }
180
181 /**
182 * Double the size of the hash table and rehash all the elements.
183 */
184 static void rehash(private_hashtable_t *this)
185 {
186 linked_list_t **old_table;
187 u_int row, old_capacity;
188
189 if (this->capacity < MAX_CAPACITY)
190 {
191 return;
192 }
193
194 old_capacity = this->capacity;
195 old_table = this->table;
196
197 init_hashtable(this, old_capacity << 1);
198
199 for (row = 0; row < old_capacity; row++)
200 {
201 enumerator_t *enumerator;
202 linked_list_t *list, *new_list;
203 pair_t *pair;
204 u_int new_row;
205
206 list = old_table[row];
207 if (list)
208 {
209 enumerator = list->create_enumerator(list);
210 while (enumerator->enumerate(enumerator, &pair))
211 {
212 new_row = pair->hash & this->mask;
213
214 list->remove_at(list, enumerator);
215 new_list = this->table[new_row];
216 if (!new_list)
217 {
218 new_list = this->table[new_row] = linked_list_create();
219 }
220 new_list->insert_last(new_list, pair);
221 }
222 enumerator->destroy(enumerator);
223 list->destroy(list);
224 }
225 }
226 free(old_table);
227 }
228
229 METHOD(hashtable_t, put, void*,
230 private_hashtable_t *this, void *key, void *value)
231 {
232 void *old_value = NULL;
233 linked_list_t *list;
234 u_int hash;
235 u_int row;
236
237 hash = this->hash(key);
238 row = hash & this->mask;
239 list = this->table[row];
240 if (list)
241 {
242 enumerator_t *enumerator;
243 pair_t *pair;
244
245 enumerator = list->create_enumerator(list);
246 while (enumerator->enumerate(enumerator, &pair))
247 {
248 if (pair_equals(pair, this, key))
249 {
250 old_value = pair->value;
251 pair->value = value;
252 break;
253 }
254 }
255 enumerator->destroy(enumerator);
256 }
257 else
258 {
259 list = this->table[row] = linked_list_create();
260 }
261 if (!old_value)
262 {
263 list->insert_last(list, pair_create(key, value, hash));
264 this->count++;
265 }
266 if (this->count >= this->capacity * this->load_factor)
267 {
268 rehash(this);
269 }
270 return old_value;
271 }
272
273 METHOD(hashtable_t, get, void*,
274 private_hashtable_t *this, void *key)
275 {
276 void *value = NULL;
277 linked_list_t *list;
278 pair_t *pair;
279
280 list = this->table[this->hash(key) & this->mask];
281 if (list)
282 {
283 if (list->find_first(list, (linked_list_match_t)pair_equals,
284 (void**)&pair, this, key) == SUCCESS)
285 {
286 value = pair->value;
287 }
288 }
289 return value;
290 }
291
292 METHOD(hashtable_t, remove_, void*,
293 private_hashtable_t *this, void *key)
294 {
295 void *value = NULL;
296 linked_list_t *list;
297
298 list = this->table[this->hash(key) & this->mask];
299 if (list)
300 {
301 enumerator_t *enumerator;
302 pair_t *pair;
303
304 enumerator = list->create_enumerator(list);
305 while (enumerator->enumerate(enumerator, &pair))
306 {
307 if (pair_equals(pair, this, key))
308 {
309 list->remove_at(list, enumerator);
310 value = pair->value;
311 this->count--;
312 free(pair);
313 break;
314 }
315 }
316 enumerator->destroy(enumerator);
317 }
318 return value;
319 }
320
321 METHOD(hashtable_t, remove_at, void,
322 private_hashtable_t *this, private_enumerator_t *enumerator)
323 {
324 if (enumerator->table == this && enumerator->current)
325 {
326 linked_list_t *list;
327 list = this->table[enumerator->row];
328 if (list)
329 {
330 list->remove_at(list, enumerator->current);
331 free(enumerator->pair);
332 this->count--;
333 }
334 }
335 }
336
337 METHOD(hashtable_t, get_count, u_int,
338 private_hashtable_t *this)
339 {
340 return this->count;
341 }
342
343 METHOD(enumerator_t, enumerate, bool,
344 private_enumerator_t *this, void **key, void **value)
345 {
346 while (this->row < this->table->capacity)
347 {
348 if (this->current)
349 {
350 if (this->current->enumerate(this->current, &this->pair))
351 {
352 if (key)
353 {
354 *key = this->pair->key;
355 }
356 if (value)
357 {
358 *value = this->pair->value;
359 }
360 return TRUE;
361 }
362 this->current->destroy(this->current);
363 this->current = NULL;
364 }
365 else
366 {
367 linked_list_t *list;
368 list = this->table->table[this->row];
369 if (list)
370 {
371 this->current = list->create_enumerator(list);
372 continue;
373 }
374 }
375 this->row++;
376 }
377 return FALSE;
378 }
379
380 METHOD(enumerator_t, enumerator_destroy, void,
381 private_enumerator_t *this)
382 {
383 if (this->current)
384 {
385 this->current->destroy(this->current);
386 }
387 free(this);
388 }
389
390 METHOD(hashtable_t, create_enumerator, enumerator_t*,
391 private_hashtable_t *this)
392 {
393 private_enumerator_t *enumerator;
394
395 INIT(enumerator,
396 .enumerator = {
397 .enumerate = (void*)_enumerate,
398 .destroy = (void*)_enumerator_destroy,
399 },
400 .table = this,
401 );
402
403 return &enumerator->enumerator;
404 }
405
406 METHOD(hashtable_t, destroy, void,
407 private_hashtable_t *this)
408 {
409 linked_list_t *list;
410 u_int row;
411
412 for (row = 0; row < this->capacity; row++)
413 {
414 list = this->table[row];
415 if (list)
416 {
417 list->destroy_function(list, free);
418 }
419 }
420 free(this->table);
421 free(this);
422 }
423
424 /*
425 * Described in header.
426 */
427 hashtable_t *hashtable_create(hashtable_hash_t hash, hashtable_equals_t equals,
428 u_int capacity)
429 {
430 private_hashtable_t *this;
431
432 INIT(this,
433 .public = {
434 .put = _put,
435 .get = _get,
436 .remove = _remove_,
437 .remove_at = (void*)_remove_at,
438 .get_count = _get_count,
439 .create_enumerator = _create_enumerator,
440 .destroy = _destroy,
441 },
442 .hash = hash,
443 .equals = equals,
444 );
445
446 init_hashtable(this, capacity);
447
448 return &this->public;
449 }
450