7e1da683e648a6c4d0b359531af2f474f5633701
[strongswan.git] / src / libstrongswan / plugins / mysql / mysql_database.c
1 /*
2 * Copyright (C) 2007 Martin Willi
3 * Hochschule fuer Technik Rapperswil
4 *
5 * This program is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License as published by the
7 * Free Software Foundation; either version 2 of the License, or (at your
8 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
9 *
10 * This program is distributed in the hope that it will be useful, but
11 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
12 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
13 * for more details.
14 */
15
16 #define _GNU_SOURCE
17 #include <string.h>
18 #include <mysql.h>
19
20 #include "mysql_database.h"
21
22 #include <utils/debug.h>
23 #include <utils/chunk.h>
24 #include <threading/thread_value.h>
25 #include <threading/mutex.h>
26 #include <collections/linked_list.h>
27
/* Older mysql.h headers lack MYSQL_DATA_TRUNCATED. mysql_stmt_fetch() does
 * not return it by default with MySQL 4, but does with MySQL 5; providing the
 * value here lets the fetch code check for it with either version. */
#if !defined(MYSQL_DATA_TRUNCATED)
#	define MYSQL_DATA_TRUNCATED 101
#endif
34
35 typedef struct private_mysql_database_t private_mysql_database_t;
36
37 /**
38 * private data of mysql_database
39 */
40 struct private_mysql_database_t {
41
42 /**
43 * public functions
44 */
45 mysql_database_t public;
46
47 /**
48 * connection pool, contains conn_t
49 */
50 linked_list_t *pool;
51
52 /**
53 * mutex to lock pool
54 */
55 mutex_t *mutex;
56
57 /**
58 * hostname to connect to
59 */
60 char *host;
61
62 /**
63 * username to use
64 */
65 char *username;
66
67 /**
68 * password
69 */
70 char *password;
71
72 /**
73 * database name
74 */
75 char *database;
76
77 /**
78 * tcp port
79 */
80 int port;
81 };
82
83 typedef struct conn_t conn_t;
84
85 /**
86 * connection pool entry
87 */
88 struct conn_t {
89
90 /**
91 * MySQL database connection
92 */
93 MYSQL *mysql;
94
95 /**
96 * connection in use?
97 */
98 bool in_use;
99 };
100
101 /**
102 * Release a mysql connection
103 */
104 static void conn_release(conn_t *conn)
105 {
106 conn->in_use = FALSE;
107 }
108
109 /**
110 * thread specific initialization flag
111 */
112 thread_value_t *initialized;
113
114 /**
115 * Initialize a thread for mysql usage
116 */
117 static void thread_initialize()
118 {
119 if (initialized->get(initialized) == NULL)
120 {
121 initialized->set(initialized, (void*)TRUE);
122 mysql_thread_init();
123 }
124 }
125
126 /**
127 * mysql library initialization function
128 */
129 bool mysql_database_init()
130 {
131 if (mysql_library_init(0, NULL, NULL))
132 {
133 return FALSE;
134 }
135 initialized = thread_value_create((thread_cleanup_t)mysql_thread_end);
136 return TRUE;
137 }
138
139 /**
140 * mysql library cleanup function
141 */
142 void mysql_database_deinit()
143 {
144 initialized->destroy(initialized);
145 mysql_thread_end();
146 /* mysql_library_end(); would be the clean way, however, it hangs... */
147 }
148
149 /**
150 * Destroy a mysql connection
151 */
152 static void conn_destroy(conn_t *this)
153 {
154 mysql_close(this->mysql);
155 free(this);
156 }
157
158 /**
159 * Acquire/Reuse a mysql connection
160 */
161 static conn_t *conn_get(private_mysql_database_t *this)
162 {
163 conn_t *current, *found = NULL;
164 enumerator_t *enumerator;
165
166 thread_initialize();
167
168 while (TRUE)
169 {
170 this->mutex->lock(this->mutex);
171 enumerator = this->pool->create_enumerator(this->pool);
172 while (enumerator->enumerate(enumerator, &current))
173 {
174 if (!current->in_use)
175 {
176 found = current;
177 found->in_use = TRUE;
178 break;
179 }
180 }
181 enumerator->destroy(enumerator);
182 this->mutex->unlock(this->mutex);
183 if (found)
184 { /* check connection if found, release if ping fails */
185 if (mysql_ping(found->mysql) == 0)
186 {
187 break;
188 }
189 this->mutex->lock(this->mutex);
190 this->pool->remove(this->pool, found, NULL);
191 this->mutex->unlock(this->mutex);
192 conn_destroy(found);
193 found = NULL;
194 continue;
195 }
196 break;
197 }
198 if (found == NULL)
199 {
200 found = malloc_thing(conn_t);
201 found->in_use = TRUE;
202 found->mysql = mysql_init(NULL);
203 if (!mysql_real_connect(found->mysql, this->host, this->username,
204 this->password, this->database, this->port,
205 NULL, 0))
206 {
207 DBG1(DBG_LIB, "connecting to mysql://%s:***@%s:%d/%s failed: %s",
208 this->username, this->host, this->port, this->database,
209 mysql_error(found->mysql));
210 conn_destroy(found);
211 found = NULL;
212 }
213 else
214 {
215 this->mutex->lock(this->mutex);
216 this->pool->insert_last(this->pool, found);
217 DBG2(DBG_LIB, "increased MySQL connection pool size to %d",
218 this->pool->get_count(this->pool));
219 this->mutex->unlock(this->mutex);
220 }
221 }
222 return found;
223 }
224
225 /**
226 * Create and run a MySQL stmt using a sql string and args
227 */
228 static MYSQL_STMT* run(MYSQL *mysql, char *sql, va_list *args)
229 {
230 MYSQL_STMT *stmt;
231 int params;
232
233 stmt = mysql_stmt_init(mysql);
234 if (stmt == NULL)
235 {
236 DBG1(DBG_LIB, "creating MySQL statement failed: %s",
237 mysql_error(mysql));
238 return NULL;
239 }
240 if (mysql_stmt_prepare(stmt, sql, strlen(sql)))
241 {
242 DBG1(DBG_LIB, "preparing MySQL statement failed: %s",
243 mysql_stmt_error(stmt));
244 mysql_stmt_close(stmt);
245 return NULL;
246 }
247 params = mysql_stmt_param_count(stmt);
248 if (params > 0)
249 {
250 int i;
251 MYSQL_BIND *bind;
252
253 bind = alloca(sizeof(MYSQL_BIND) * params);
254 memset(bind, 0, sizeof(MYSQL_BIND) * params);
255
256 for (i = 0; i < params; i++)
257 {
258 switch (va_arg(*args, db_type_t))
259 {
260 case DB_INT:
261 {
262 bind[i].buffer_type = MYSQL_TYPE_LONG;
263 bind[i].buffer = (char*)alloca(sizeof(int));
264 *(int*)bind[i].buffer = va_arg(*args, int);
265 bind[i].buffer_length = sizeof(int);
266 break;
267 }
268 case DB_UINT:
269 {
270 bind[i].buffer_type = MYSQL_TYPE_LONG;
271 bind[i].buffer = (char*)alloca(sizeof(u_int));
272 *(u_int*)bind[i].buffer = va_arg(*args, u_int);
273 bind[i].buffer_length = sizeof(u_int);
274 bind[i].is_unsigned = TRUE;
275 break;
276 }
277 case DB_TEXT:
278 {
279 bind[i].buffer_type = MYSQL_TYPE_STRING;;
280 bind[i].buffer = va_arg(*args, char*);
281 if (bind[i].buffer)
282 {
283 bind[i].buffer_length = strlen(bind[i].buffer);
284 }
285 break;
286 }
287 case DB_BLOB:
288 {
289 chunk_t chunk = va_arg(*args, chunk_t);
290 bind[i].buffer_type = MYSQL_TYPE_BLOB;
291 bind[i].buffer = chunk.ptr;
292 bind[i].buffer_length = chunk.len;
293 break;
294 }
295 case DB_DOUBLE:
296 {
297 bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
298 bind[i].buffer = (char*)alloca(sizeof(double));
299 *(double*)bind[i].buffer = va_arg(*args, double);
300 bind[i].buffer_length = sizeof(double);
301 break;
302 }
303 case DB_NULL:
304 {
305 bind[i].buffer_type = MYSQL_TYPE_NULL;
306 break;
307 }
308 default:
309 DBG1(DBG_LIB, "invalid data type supplied");
310 mysql_stmt_close(stmt);
311 return NULL;
312 }
313 }
314 if (mysql_stmt_bind_param(stmt, bind))
315 {
316 DBG1(DBG_LIB, "binding MySQL param failed: %s",
317 mysql_stmt_error(stmt));
318 mysql_stmt_close(stmt);
319 return NULL;
320 }
321 }
322 if (mysql_stmt_execute(stmt))
323 {
324 DBG1(DBG_LIB, "executing MySQL statement failed: %s",
325 mysql_stmt_error(stmt));
326 mysql_stmt_close(stmt);
327 return NULL;
328 }
329 return stmt;
330 }
331
332 typedef struct {
333 /** implements enumerator_t */
334 enumerator_t public;
335 /** associated MySQL statement */
336 MYSQL_STMT *stmt;
337 /** result bindings */
338 MYSQL_BIND *bind;
339 /** pooled connection handle */
340 conn_t *conn;
341 /** value for INT, UINT, double */
342 union {
343 void *p_void;;
344 int *p_int;
345 u_int *p_uint;
346 double *p_double;
347 } val;
348 /* length for TEXT and BLOB */
349 unsigned long *length;
350 } mysql_enumerator_t;
351
352 /**
353 * create a mysql enumerator
354 */
355 static void mysql_enumerator_destroy(mysql_enumerator_t *this)
356 {
357 int columns, i;
358
359 columns = mysql_stmt_field_count(this->stmt);
360
361 for (i = 0; i < columns; i++)
362 {
363 switch (this->bind[i].buffer_type)
364 {
365 case MYSQL_TYPE_STRING:
366 case MYSQL_TYPE_BLOB:
367 {
368 free(this->bind[i].buffer);
369 break;
370 }
371 default:
372 break;
373 }
374 }
375 mysql_stmt_close(this->stmt);
376 conn_release(this->conn);
377 free(this->bind);
378 free(this->val.p_void);
379 free(this->length);
380 free(this);
381 }
382
383 /**
384 * Implementation of database.query().enumerate
385 */
386 static bool mysql_enumerator_enumerate(mysql_enumerator_t *this, ...)
387 {
388 int i, columns;
389 va_list args;
390
391 columns = mysql_stmt_field_count(this->stmt);
392
393 /* free/reset data set of previous call */
394 for (i = 0; i < columns; i++)
395 {
396 switch (this->bind[i].buffer_type)
397 {
398 case MYSQL_TYPE_STRING:
399 case MYSQL_TYPE_BLOB:
400 {
401 free(this->bind[i].buffer);
402 this->bind[i].buffer = NULL;
403 this->bind[i].buffer_length = 0;
404 this->bind[i].length = &this->length[i];
405 this->length[i] = 0;
406 break;
407 }
408 default:
409 break;
410 }
411 }
412
413 switch (mysql_stmt_fetch(this->stmt))
414 {
415 case 0:
416 case MYSQL_DATA_TRUNCATED:
417 break;
418 case MYSQL_NO_DATA:
419 return FALSE;
420 default:
421 DBG1(DBG_LIB, "fetching MySQL row failed: %s",
422 mysql_stmt_error(this->stmt));
423 return FALSE;
424 }
425
426 va_start(args, this);
427 for (i = 0; i < columns; i++)
428 {
429 switch (this->bind[i].buffer_type)
430 {
431 case MYSQL_TYPE_LONG:
432 {
433 if (this->bind[i].is_unsigned)
434 {
435 u_int *value = va_arg(args, u_int*);
436 *value = this->val.p_uint[i];
437 }
438 else
439 {
440 int *value = va_arg(args, int*);
441 *value = this->val.p_int[i];
442 }
443 break;
444 }
445 case MYSQL_TYPE_STRING:
446 {
447 char **value = va_arg(args, char**);
448 this->bind[i].buffer = malloc(this->length[i]+1);
449 this->bind[i].buffer_length = this->length[i];
450 *value = this->bind[i].buffer;
451 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
452 ((char*)this->bind[i].buffer)[this->length[i]] = '\0';
453 break;
454 }
455 case MYSQL_TYPE_BLOB:
456 {
457 chunk_t *value = va_arg(args, chunk_t*);
458 this->bind[i].buffer = malloc(this->length[i]);
459 this->bind[i].buffer_length = this->length[i];
460 value->ptr = this->bind[i].buffer;
461 value->len = this->length[i];
462 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
463 break;
464 }
465 case MYSQL_TYPE_DOUBLE:
466 {
467 double *value = va_arg(args, double*);
468 *value = this->val.p_double[i];
469 break;
470 }
471 default:
472 break;
473 }
474 }
475 va_end(args);
476 return TRUE;
477 }
478
479 METHOD(database_t, query, enumerator_t*,
480 private_mysql_database_t *this, char *sql, ...)
481 {
482 MYSQL_STMT *stmt;
483 va_list args;
484 mysql_enumerator_t *enumerator = NULL;
485 conn_t *conn;
486
487 conn = conn_get(this);
488 if (!conn)
489 {
490 return NULL;
491 }
492
493 va_start(args, sql);
494 stmt = run(conn->mysql, sql, &args);
495 if (stmt)
496 {
497 int columns, i;
498
499 enumerator = malloc_thing(mysql_enumerator_t);
500 enumerator->public.enumerate = (void*)mysql_enumerator_enumerate;
501 enumerator->public.destroy = (void*)mysql_enumerator_destroy;
502 enumerator->stmt = stmt;
503 enumerator->conn = conn;
504 columns = mysql_stmt_field_count(stmt);
505 enumerator->bind = calloc(columns, sizeof(MYSQL_BIND));
506 enumerator->length = calloc(columns, sizeof(unsigned long));
507 enumerator->val.p_void = calloc(columns, sizeof(enumerator->val));
508 for (i = 0; i < columns; i++)
509 {
510 switch (va_arg(args, db_type_t))
511 {
512 case DB_INT:
513 {
514 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
515 enumerator->bind[i].buffer = (char*)&enumerator->val.p_int[i];
516 break;
517 }
518 case DB_UINT:
519 {
520 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
521 enumerator->bind[i].buffer = (char*)&enumerator->val.p_uint[i];
522 enumerator->bind[i].is_unsigned = TRUE;
523 break;
524 }
525 case DB_TEXT:
526 {
527 enumerator->bind[i].buffer_type = MYSQL_TYPE_STRING;
528 enumerator->bind[i].length = &enumerator->length[i];
529 break;
530 }
531 case DB_BLOB:
532 {
533 enumerator->bind[i].buffer_type = MYSQL_TYPE_BLOB;
534 enumerator->bind[i].length = &enumerator->length[i];
535 break;
536 }
537 case DB_DOUBLE:
538 {
539 enumerator->bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
540 enumerator->bind[i].buffer = (char*)&enumerator->val.p_double[i];
541 break;
542 }
543 default:
544 DBG1(DBG_LIB, "invalid result data type supplied");
545 mysql_enumerator_destroy(enumerator);
546 va_end(args);
547 return NULL;
548 }
549 }
550 if (mysql_stmt_bind_result(stmt, enumerator->bind))
551 {
552 DBG1(DBG_LIB, "binding MySQL result failed: %s",
553 mysql_stmt_error(stmt));
554 mysql_enumerator_destroy(enumerator);
555 enumerator = NULL;
556 }
557 }
558 else
559 {
560 conn_release(conn);
561 }
562 va_end(args);
563 return (enumerator_t*)enumerator;
564 }
565
566 METHOD(database_t, execute, int,
567 private_mysql_database_t *this, int *rowid, char *sql, ...)
568 {
569 MYSQL_STMT *stmt;
570 va_list args;
571 conn_t *conn;
572 int affected = -1;
573
574 conn = conn_get(this);
575 if (!conn)
576 {
577 return -1;
578 }
579 va_start(args, sql);
580 stmt = run(conn->mysql, sql, &args);
581 if (stmt)
582 {
583 if (rowid)
584 {
585 *rowid = mysql_stmt_insert_id(stmt);
586 }
587 affected = mysql_stmt_affected_rows(stmt);
588 mysql_stmt_close(stmt);
589 }
590 va_end(args);
591 conn_release(conn);
592 return affected;
593 }
594
595 METHOD(database_t, get_driver,db_driver_t,
596 private_mysql_database_t *this)
597 {
598 return DB_MYSQL;
599 }
600
601 METHOD(database_t, destroy, void,
602 private_mysql_database_t *this)
603 {
604 this->pool->destroy_function(this->pool, (void*)conn_destroy);
605 this->mutex->destroy(this->mutex);
606 free(this->host);
607 free(this->username);
608 free(this->password);
609 free(this->database);
610 free(this);
611 }
612
613 static bool parse_uri(private_mysql_database_t *this, char *uri)
614 {
615 char *username, *password, *host, *port = "0", *database, *pos;
616
617 /**
618 * parse mysql://username:pass@host:port/database uri
619 */
620 username = strdupa(uri + 8);
621 pos = strchr(username, ':');
622 if (pos)
623 {
624 *pos = '\0';
625 password = pos + 1;
626 pos = strrchr(password, '@');
627 if (pos)
628 {
629 *pos = '\0';
630 host = pos + 1;
631 pos = strrchr(host, ':');
632 if (pos)
633 {
634 *pos = '\0';
635 port = pos + 1;
636 pos = strchr(port, '/');
637 }
638 else
639 {
640 pos = strchr(host, '/');
641 }
642 if (pos)
643 {
644 *pos = '\0';
645 database = pos + 1;
646
647 this->host = strdup(host);
648 this->username = strdup(username);
649 this->password = strdup(password);
650 this->database = strdup(database);
651 this->port = atoi(port);
652 return TRUE;
653 }
654 }
655 }
656 DBG1(DBG_LIB, "parsing MySQL database uri '%s' failed", uri);
657 return FALSE;
658 }
659
660
661 /*
662 * see header file
663 */
664 mysql_database_t *mysql_database_create(char *uri)
665 {
666 conn_t *conn;
667 private_mysql_database_t *this;
668
669 if (!strneq(uri, "mysql://", 8))
670 {
671 return NULL;
672 }
673
674 INIT(this,
675 .public = {
676 .db = {
677 .query = _query,
678 .execute = _execute,
679 .get_driver = _get_driver,
680 .destroy = _destroy,
681 },
682 },
683 );
684
685 if (!parse_uri(this, uri))
686 {
687 free(this);
688 return NULL;
689 }
690 this->mutex = mutex_create(MUTEX_TYPE_DEFAULT);
691 this->pool = linked_list_create();
692
693 /* check connectivity */
694 conn = conn_get(this);
695 if (!conn)
696 {
697 destroy(this);
698 return NULL;
699 }
700 conn_release(conn);
701 return &this->public;
702 }
703