Added missing include in mysql plugin.
[strongswan.git] / src / libstrongswan / plugins / mysql / mysql_database.c
1 /*
2 * Copyright (C) 2007 Martin Willi
3 * Hochschule fuer Technik Rapperswil
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #define _GNU_SOURCE
17 #include <string.h>
18 #include <mysql.h>
19
20 #include "mysql_database.h"
21
22 #include <debug.h>
23 #include <chunk.h>
24 #include <threading/thread_value.h>
25 #include <threading/mutex.h>
26 #include <utils/linked_list.h>
27
28 /* Older mysql.h headers do not define it, but we need it. It is not returned
29 * in MySQL 4 by default, but by MySQL 5. To avoid this problem, we catch
30 * it in all cases. */
31 #ifndef MYSQL_DATA_TRUNCATED
32 #define MYSQL_DATA_TRUNCATED 101
33 #endif
34
35 typedef struct private_mysql_database_t private_mysql_database_t;
36
37 /**
38 * private data of mysql_database
39 */
40 struct private_mysql_database_t {
41
42 /**
43 * public functions
44 */
45 mysql_database_t public;
46
47 /**
48 * connection pool, contains conn_t
49 */
50 linked_list_t *pool;
51
52 /**
53 * mutex to lock pool
54 */
55 mutex_t *mutex;
56
57 /**
58 * hostname to connect to
59 */
60 char *host;
61
62 /**
63 * username to use
64 */
65 char *username;
66
67 /**
68 * password
69 */
70 char *password;
71
72 /**
73 * database name
74 */
75 char *database;
76
77 /**
78 * tcp port
79 */
80 int port;
81 };
82
83 typedef struct conn_t conn_t;
84
85 /**
86 * connection pool entry
87 */
88 struct conn_t {
89
90 /**
91 * MySQL database connection
92 */
93 MYSQL *mysql;
94
95 /**
96 * connection in use?
97 */
98 bool in_use;
99 };
100
101 /**
102 * Release a mysql connection
103 */
104 static void conn_release(conn_t *conn)
105 {
106 conn->in_use = FALSE;
107 }
108
109 /**
110 * thread specific initialization flag
111 */
112 thread_value_t *initialized;
113
114 /**
115 * Initialize a thread for mysql usage
116 */
117 static void thread_initialize()
118 {
119 if (initialized->get(initialized) == NULL)
120 {
121 initialized->set(initialized, (void*)TRUE);
122 mysql_thread_init();
123 }
124 }
125
126 /**
127 * mysql library initialization function
128 */
129 bool mysql_database_init()
130 {
131 if (mysql_library_init(0, NULL, NULL))
132 {
133 return FALSE;
134 }
135 initialized = thread_value_create((thread_cleanup_t)mysql_thread_end);
136 return TRUE;
137 }
138
139 /**
140 * mysql library cleanup function
141 */
142 void mysql_database_deinit()
143 {
144 initialized->destroy(initialized);
145 mysql_thread_end();
146 /* mysql_library_end(); would be the clean way, however, it hangs... */
147 }
148
149 /**
150 * Destroy a mysql connection
151 */
152 static void conn_destroy(conn_t *this)
153 {
154 mysql_close(this->mysql);
155 free(this);
156 }
157
158 /**
159 * Acquire/Reuse a mysql connection
160 */
161 static conn_t *conn_get(private_mysql_database_t *this)
162 {
163 conn_t *current, *found = NULL;
164 enumerator_t *enumerator;
165
166 thread_initialize();
167
168 while (TRUE)
169 {
170 this->mutex->lock(this->mutex);
171 enumerator = this->pool->create_enumerator(this->pool);
172 while (enumerator->enumerate(enumerator, &current))
173 {
174 if (!current->in_use)
175 {
176 found = current;
177 found->in_use = TRUE;
178 break;
179 }
180 }
181 enumerator->destroy(enumerator);
182 this->mutex->unlock(this->mutex);
183 if (found)
184 { /* check connection if found, release if ping fails */
185 if (mysql_ping(found->mysql) == 0)
186 {
187 break;
188 }
189 this->mutex->lock(this->mutex);
190 this->pool->remove(this->pool, found, NULL);
191 this->mutex->unlock(this->mutex);
192 conn_destroy(found);
193 found = NULL;
194 continue;
195 }
196 break;
197 }
198 if (found == NULL)
199 {
200 found = malloc_thing(conn_t);
201 found->in_use = TRUE;
202 found->mysql = mysql_init(NULL);
203 if (!mysql_real_connect(found->mysql, this->host, this->username,
204 this->password, this->database, this->port,
205 NULL, 0))
206 {
207 DBG1(DBG_LIB, "connecting to mysql://%s:***@%s:%d/%s failed: %s",
208 this->username, this->host, this->port, this->database,
209 mysql_error(found->mysql));
210 conn_destroy(found);
211 found = NULL;
212 }
213 else
214 {
215 this->mutex->lock(this->mutex);
216 this->pool->insert_last(this->pool, found);
217 DBG2(DBG_LIB, "increased MySQL connection pool size to %d",
218 this->pool->get_count(this->pool));
219 this->mutex->unlock(this->mutex);
220 }
221 }
222 return found;
223 }
224
225 /**
226 * Create and run a MySQL stmt using a sql string and args
227 */
228 static MYSQL_STMT* run(MYSQL *mysql, char *sql, va_list *args)
229 {
230 MYSQL_STMT *stmt;
231 int params;
232
233 stmt = mysql_stmt_init(mysql);
234 if (stmt == NULL)
235 {
236 DBG1(DBG_LIB, "creating MySQL statement failed: %s",
237 mysql_error(mysql));
238 return NULL;
239 }
240 if (mysql_stmt_prepare(stmt, sql, strlen(sql)))
241 {
242 DBG1(DBG_LIB, "preparing MySQL statement failed: %s",
243 mysql_stmt_error(stmt));
244 mysql_stmt_close(stmt);
245 return NULL;
246 }
247 params = mysql_stmt_param_count(stmt);
248 if (params > 0)
249 {
250 int i;
251 MYSQL_BIND *bind;
252
253 bind = alloca(sizeof(MYSQL_BIND) * params);
254 memset(bind, 0, sizeof(MYSQL_BIND) * params);
255
256 for (i = 0; i < params; i++)
257 {
258 switch (va_arg(*args, db_type_t))
259 {
260 case DB_INT:
261 {
262 bind[i].buffer_type = MYSQL_TYPE_LONG;
263 bind[i].buffer = (char*)alloca(sizeof(int));
264 *(int*)bind[i].buffer = va_arg(*args, int);
265 bind[i].buffer_length = sizeof(int);
266 break;
267 }
268 case DB_UINT:
269 {
270 bind[i].buffer_type = MYSQL_TYPE_LONG;
271 bind[i].buffer = (char*)alloca(sizeof(u_int));
272 *(u_int*)bind[i].buffer = va_arg(*args, u_int);
273 bind[i].buffer_length = sizeof(u_int);
274 bind[i].is_unsigned = TRUE;
275 break;
276 }
277 case DB_TEXT:
278 {
279 bind[i].buffer_type = MYSQL_TYPE_STRING;;
280 bind[i].buffer = va_arg(*args, char*);
281 if (bind[i].buffer)
282 {
283 bind[i].buffer_length = strlen(bind[i].buffer);
284 }
285 break;
286 }
287 case DB_BLOB:
288 {
289 chunk_t chunk = va_arg(*args, chunk_t);
290 bind[i].buffer_type = MYSQL_TYPE_BLOB;
291 bind[i].buffer = chunk.ptr;
292 bind[i].buffer_length = chunk.len;
293 break;
294 }
295 case DB_DOUBLE:
296 {
297 bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
298 bind[i].buffer = (char*)alloca(sizeof(double));
299 *(double*)bind[i].buffer = va_arg(*args, double);
300 bind[i].buffer_length = sizeof(double);
301 break;
302 }
303 case DB_NULL:
304 {
305 bind[i].buffer_type = MYSQL_TYPE_NULL;
306 break;
307 }
308 default:
309 DBG1(DBG_LIB, "invalid data type supplied");
310 mysql_stmt_close(stmt);
311 return NULL;
312 }
313 }
314 if (mysql_stmt_bind_param(stmt, bind))
315 {
316 DBG1(DBG_LIB, "binding MySQL param failed: %s",
317 mysql_stmt_error(stmt));
318 mysql_stmt_close(stmt);
319 return NULL;
320 }
321 }
322 if (mysql_stmt_execute(stmt))
323 {
324 DBG1(DBG_LIB, "executing MySQL statement failed: %s",
325 mysql_stmt_error(stmt));
326 mysql_stmt_close(stmt);
327 return NULL;
328 }
329 return stmt;
330 }
331
332 typedef struct {
333 /** implements enumerator_t */
334 enumerator_t public;
335 /** associated MySQL statement */
336 MYSQL_STMT *stmt;
337 /** result bindings */
338 MYSQL_BIND *bind;
339 /** pooled connection handle */
340 conn_t *conn;
341 /** value for INT, UINT, double */
342 union {
343 void *p_void;;
344 int *p_int;
345 u_int *p_uint;
346 double *p_double;
347 } val;
348 /* length for TEXT and BLOB */
349 unsigned long *length;
350 } mysql_enumerator_t;
351
352 /**
353 * create a mysql enumerator
354 */
355 static void mysql_enumerator_destroy(mysql_enumerator_t *this)
356 {
357 int columns, i;
358
359 columns = mysql_stmt_field_count(this->stmt);
360
361 for (i = 0; i < columns; i++)
362 {
363 switch (this->bind[i].buffer_type)
364 {
365 case MYSQL_TYPE_STRING:
366 case MYSQL_TYPE_BLOB:
367 {
368 free(this->bind[i].buffer);
369 break;
370 }
371 default:
372 break;
373 }
374 }
375 mysql_stmt_close(this->stmt);
376 conn_release(this->conn);
377 free(this->bind);
378 free(this->val.p_void);
379 free(this->length);
380 free(this);
381 }
382
383 /**
384 * Implementation of database.query().enumerate
385 */
386 static bool mysql_enumerator_enumerate(mysql_enumerator_t *this, ...)
387 {
388 int i, columns;
389 va_list args;
390
391 columns = mysql_stmt_field_count(this->stmt);
392
393 /* free/reset data set of previous call */
394 for (i = 0; i < columns; i++)
395 {
396 switch (this->bind[i].buffer_type)
397 {
398 case MYSQL_TYPE_STRING:
399 case MYSQL_TYPE_BLOB:
400 {
401 free(this->bind[i].buffer);
402 this->bind[i].buffer = NULL;
403 this->bind[i].buffer_length = 0;
404 this->bind[i].length = &this->length[i];
405 this->length[i] = 0;
406 break;
407 }
408 default:
409 break;
410 }
411 }
412
413 switch (mysql_stmt_fetch(this->stmt))
414 {
415 case 0:
416 case MYSQL_DATA_TRUNCATED:
417 break;
418 case MYSQL_NO_DATA:
419 return FALSE;
420 default:
421 DBG1(DBG_LIB, "fetching MySQL row failed: %s",
422 mysql_stmt_error(this->stmt));
423 return FALSE;
424 }
425
426 va_start(args, this);
427 for (i = 0; i < columns; i++)
428 {
429 switch (this->bind[i].buffer_type)
430 {
431 case MYSQL_TYPE_LONG:
432 {
433 if (this->bind[i].is_unsigned)
434 {
435 u_int *value = va_arg(args, u_int*);
436 *value = this->val.p_uint[i];
437 }
438 else
439 {
440 int *value = va_arg(args, int*);
441 *value = this->val.p_int[i];
442 }
443 break;
444 }
445 case MYSQL_TYPE_STRING:
446 {
447 char **value = va_arg(args, char**);
448 this->bind[i].buffer = malloc(this->length[i]+1);
449 this->bind[i].buffer_length = this->length[i];
450 *value = this->bind[i].buffer;
451 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
452 ((char*)this->bind[i].buffer)[this->length[i]] = '\0';
453 break;
454 }
455 case MYSQL_TYPE_BLOB:
456 {
457 chunk_t *value = va_arg(args, chunk_t*);
458 this->bind[i].buffer = malloc(this->length[i]);
459 this->bind[i].buffer_length = this->length[i];
460 value->ptr = this->bind[i].buffer;
461 value->len = this->length[i];
462 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
463 break;
464 }
465 case MYSQL_TYPE_DOUBLE:
466 {
467 double *value = va_arg(args, double*);
468 *value = this->val.p_double[i];
469 break;
470 }
471 default:
472 break;
473 }
474 }
475 return TRUE;
476 }
477
478 METHOD(database_t, query, enumerator_t*,
479 private_mysql_database_t *this, char *sql, ...)
480 {
481 MYSQL_STMT *stmt;
482 va_list args;
483 mysql_enumerator_t *enumerator = NULL;
484 conn_t *conn;
485
486 conn = conn_get(this);
487 if (!conn)
488 {
489 return NULL;
490 }
491
492 va_start(args, sql);
493 stmt = run(conn->mysql, sql, &args);
494 if (stmt)
495 {
496 int columns, i;
497
498 enumerator = malloc_thing(mysql_enumerator_t);
499 enumerator->public.enumerate = (void*)mysql_enumerator_enumerate;
500 enumerator->public.destroy = (void*)mysql_enumerator_destroy;
501 enumerator->stmt = stmt;
502 enumerator->conn = conn;
503 columns = mysql_stmt_field_count(stmt);
504 enumerator->bind = calloc(columns, sizeof(MYSQL_BIND));
505 enumerator->length = calloc(columns, sizeof(unsigned long));
506 enumerator->val.p_void = calloc(columns, sizeof(enumerator->val));
507 for (i = 0; i < columns; i++)
508 {
509 switch (va_arg(args, db_type_t))
510 {
511 case DB_INT:
512 {
513 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
514 enumerator->bind[i].buffer = (char*)&enumerator->val.p_int[i];
515 break;
516 }
517 case DB_UINT:
518 {
519 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
520 enumerator->bind[i].buffer = (char*)&enumerator->val.p_uint[i];
521 enumerator->bind[i].is_unsigned = TRUE;
522 break;
523 }
524 case DB_TEXT:
525 {
526 enumerator->bind[i].buffer_type = MYSQL_TYPE_STRING;
527 enumerator->bind[i].length = &enumerator->length[i];
528 break;
529 }
530 case DB_BLOB:
531 {
532 enumerator->bind[i].buffer_type = MYSQL_TYPE_BLOB;
533 enumerator->bind[i].length = &enumerator->length[i];
534 break;
535 }
536 case DB_DOUBLE:
537 {
538 enumerator->bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
539 enumerator->bind[i].buffer = (char*)&enumerator->val.p_double[i];
540 break;
541 }
542 default:
543 DBG1(DBG_LIB, "invalid result data type supplied");
544 mysql_enumerator_destroy(enumerator);
545 va_end(args);
546 return NULL;
547 }
548 }
549 if (mysql_stmt_bind_result(stmt, enumerator->bind))
550 {
551 DBG1(DBG_LIB, "binding MySQL result failed: %s",
552 mysql_stmt_error(stmt));
553 mysql_enumerator_destroy(enumerator);
554 enumerator = NULL;
555 }
556 }
557 else
558 {
559 conn_release(conn);
560 }
561 va_end(args);
562 return (enumerator_t*)enumerator;
563 }
564
565 METHOD(database_t, execute, int,
566 private_mysql_database_t *this, int *rowid, char *sql, ...)
567 {
568 MYSQL_STMT *stmt;
569 va_list args;
570 conn_t *conn;
571 int affected = -1;
572
573 conn = conn_get(this);
574 if (!conn)
575 {
576 return -1;
577 }
578 va_start(args, sql);
579 stmt = run(conn->mysql, sql, &args);
580 if (stmt)
581 {
582 if (rowid)
583 {
584 *rowid = mysql_stmt_insert_id(stmt);
585 }
586 affected = mysql_stmt_affected_rows(stmt);
587 mysql_stmt_close(stmt);
588 }
589 va_end(args);
590 conn_release(conn);
591 return affected;
592 }
593
594 METHOD(database_t, get_driver,db_driver_t,
595 private_mysql_database_t *this)
596 {
597 return DB_MYSQL;
598 }
599
600 METHOD(database_t, destroy, void,
601 private_mysql_database_t *this)
602 {
603 this->pool->destroy_function(this->pool, (void*)conn_destroy);
604 this->mutex->destroy(this->mutex);
605 free(this->host);
606 free(this->username);
607 free(this->password);
608 free(this->database);
609 free(this);
610 }
611
612 static bool parse_uri(private_mysql_database_t *this, char *uri)
613 {
614 char *username, *password, *host, *port = "0", *database, *pos;
615
616 /**
617 * parse mysql://username:pass@host:port/database uri
618 */
619 username = strdupa(uri + 8);
620 pos = strchr(username, ':');
621 if (pos)
622 {
623 *pos = '\0';
624 password = pos + 1;
625 pos = strrchr(password, '@');
626 if (pos)
627 {
628 *pos = '\0';
629 host = pos + 1;
630 pos = strrchr(host, ':');
631 if (pos)
632 {
633 *pos = '\0';
634 port = pos + 1;
635 pos = strchr(port, '/');
636 }
637 else
638 {
639 pos = strchr(host, '/');
640 }
641 if (pos)
642 {
643 *pos = '\0';
644 database = pos + 1;
645
646 this->host = strdup(host);
647 this->username = strdup(username);
648 this->password = strdup(password);
649 this->database = strdup(database);
650 this->port = atoi(port);
651 return TRUE;
652 }
653 }
654 }
655 DBG1(DBG_LIB, "parsing MySQL database uri '%s' failed", uri);
656 return FALSE;
657 }
658
659
660 /*
661 * see header file
662 */
663 mysql_database_t *mysql_database_create(char *uri)
664 {
665 conn_t *conn;
666 private_mysql_database_t *this;
667
668 if (!strneq(uri, "mysql://", 8))
669 {
670 return NULL;
671 }
672
673 INIT(this,
674 .public = {
675 .db = {
676 .query = _query,
677 .execute = _execute,
678 .get_driver = _get_driver,
679 .destroy = _destroy,
680 },
681 },
682 );
683
684 if (!parse_uri(this, uri))
685 {
686 free(this);
687 return NULL;
688 }
689 this->mutex = mutex_create(MUTEX_TYPE_DEFAULT);
690 this->pool = linked_list_create();
691
692 /* check connectivity */
693 conn = conn_get(this);
694 if (!conn)
695 {
696 destroy(this);
697 return NULL;
698 }
699 conn_release(conn);
700 return &this->public;
701 }
702