database: Add interface to handle transactions
[strongswan.git] / src / libstrongswan / plugins / mysql / mysql_database.c
1 /*
2 * Copyright (C) 2013 Tobias Brunner
3 * Copyright (C) 2007 Martin Willi
4 * Hochschule fuer Technik Rapperswil
5 *
6 * This program is free software; you can redistribute it and/or modify it
7 * under the terms of the GNU General Public License as published by the
8 * Free Software Foundation; either version 2 of the License, or (at your
9 * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
10 *
11 * This program is distributed in the hope that it will be useful, but
12 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
13 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
14 * for more details.
15 */
16
17 #define _GNU_SOURCE
18 #include <string.h>
19 #include <mysql.h>
20
21 #include "mysql_database.h"
22
23 #include <utils/debug.h>
24 #include <utils/chunk.h>
25 #include <threading/thread_value.h>
26 #include <threading/mutex.h>
27 #include <collections/linked_list.h>
28
/* Older mysql.h headers do not define MYSQL_DATA_TRUNCATED, but we need it.
 * MySQL 4 does not return it by default while MySQL 5 does, so to handle
 * both we catch it in all cases. */
#ifndef MYSQL_DATA_TRUNCATED
#define MYSQL_DATA_TRUNCATED 101
#endif
35
36 typedef struct private_mysql_database_t private_mysql_database_t;
37
38 /**
39 * private data of mysql_database
40 */
41 struct private_mysql_database_t {
42
43 /**
44 * public functions
45 */
46 mysql_database_t public;
47
48 /**
49 * connection pool, contains conn_t
50 */
51 linked_list_t *pool;
52
53 /**
54 * mutex to lock pool
55 */
56 mutex_t *mutex;
57
58 /**
59 * hostname to connect to
60 */
61 char *host;
62
63 /**
64 * username to use
65 */
66 char *username;
67
68 /**
69 * password
70 */
71 char *password;
72
73 /**
74 * database name
75 */
76 char *database;
77
78 /**
79 * tcp port
80 */
81 int port;
82 };
83
84 typedef struct conn_t conn_t;
85
86 /**
87 * connection pool entry
88 */
89 struct conn_t {
90
91 /**
92 * MySQL database connection
93 */
94 MYSQL *mysql;
95
96 /**
97 * connection in use?
98 */
99 bool in_use;
100 };
101
102 /**
103 * Release a mysql connection
104 */
105 static void conn_release(private_mysql_database_t *this, conn_t *conn)
106 {
107 this->mutex->lock(this->mutex);
108 conn->in_use = FALSE;
109 this->mutex->unlock(this->mutex);
110 }
111
112 /**
113 * thread specific initialization flag
114 */
115 thread_value_t *initialized;
116
117 /**
118 * Initialize a thread for mysql usage
119 */
120 static void thread_initialize()
121 {
122 if (initialized->get(initialized) == NULL)
123 {
124 initialized->set(initialized, (void*)TRUE);
125 mysql_thread_init();
126 }
127 }
128
129 /**
130 * mysql library initialization function
131 */
132 bool mysql_database_init()
133 {
134 if (mysql_library_init(0, NULL, NULL))
135 {
136 return FALSE;
137 }
138 initialized = thread_value_create((thread_cleanup_t)mysql_thread_end);
139 return TRUE;
140 }
141
142 /**
143 * mysql library cleanup function
144 */
145 void mysql_database_deinit()
146 {
147 initialized->destroy(initialized);
148 mysql_thread_end();
149 mysql_library_end();
150 }
151
152 /**
153 * Destroy a mysql connection
154 */
155 static void conn_destroy(conn_t *this)
156 {
157 mysql_close(this->mysql);
158 free(this);
159 }
160
161 /**
162 * Acquire/Reuse a mysql connection
163 */
164 static conn_t *conn_get(private_mysql_database_t *this)
165 {
166 conn_t *current, *found = NULL;
167 enumerator_t *enumerator;
168
169 thread_initialize();
170
171 while (TRUE)
172 {
173 this->mutex->lock(this->mutex);
174 enumerator = this->pool->create_enumerator(this->pool);
175 while (enumerator->enumerate(enumerator, &current))
176 {
177 if (!current->in_use)
178 {
179 found = current;
180 found->in_use = TRUE;
181 break;
182 }
183 }
184 enumerator->destroy(enumerator);
185 this->mutex->unlock(this->mutex);
186 if (found)
187 { /* check connection if found, release if ping fails */
188 if (mysql_ping(found->mysql) == 0)
189 {
190 break;
191 }
192 this->mutex->lock(this->mutex);
193 this->pool->remove(this->pool, found, NULL);
194 this->mutex->unlock(this->mutex);
195 conn_destroy(found);
196 found = NULL;
197 continue;
198 }
199 break;
200 }
201 if (found == NULL)
202 {
203 INIT(found,
204 .in_use = TRUE,
205 .mysql = mysql_init(NULL),
206 );
207 if (!mysql_real_connect(found->mysql, this->host, this->username,
208 this->password, this->database, this->port,
209 NULL, 0))
210 {
211 DBG1(DBG_LIB, "connecting to mysql://%s:***@%s:%d/%s failed: %s",
212 this->username, this->host, this->port, this->database,
213 mysql_error(found->mysql));
214 conn_destroy(found);
215 found = NULL;
216 }
217 else
218 {
219 this->mutex->lock(this->mutex);
220 this->pool->insert_last(this->pool, found);
221 DBG2(DBG_LIB, "increased MySQL connection pool size to %d",
222 this->pool->get_count(this->pool));
223 this->mutex->unlock(this->mutex);
224 }
225 }
226 return found;
227 }
228
229 /**
230 * Create and run a MySQL stmt using a sql string and args
231 */
232 static MYSQL_STMT* run(MYSQL *mysql, char *sql, va_list *args)
233 {
234 MYSQL_STMT *stmt;
235 int params;
236
237 stmt = mysql_stmt_init(mysql);
238 if (stmt == NULL)
239 {
240 DBG1(DBG_LIB, "creating MySQL statement failed: %s",
241 mysql_error(mysql));
242 return NULL;
243 }
244 if (mysql_stmt_prepare(stmt, sql, strlen(sql)))
245 {
246 DBG1(DBG_LIB, "preparing MySQL statement failed: %s",
247 mysql_stmt_error(stmt));
248 mysql_stmt_close(stmt);
249 return NULL;
250 }
251 params = mysql_stmt_param_count(stmt);
252 if (params > 0)
253 {
254 int i;
255 MYSQL_BIND *bind;
256
257 bind = alloca(sizeof(MYSQL_BIND) * params);
258 memset(bind, 0, sizeof(MYSQL_BIND) * params);
259
260 for (i = 0; i < params; i++)
261 {
262 switch (va_arg(*args, db_type_t))
263 {
264 case DB_INT:
265 {
266 bind[i].buffer_type = MYSQL_TYPE_LONG;
267 bind[i].buffer = (char*)alloca(sizeof(int));
268 *(int*)bind[i].buffer = va_arg(*args, int);
269 bind[i].buffer_length = sizeof(int);
270 break;
271 }
272 case DB_UINT:
273 {
274 bind[i].buffer_type = MYSQL_TYPE_LONG;
275 bind[i].buffer = (char*)alloca(sizeof(u_int));
276 *(u_int*)bind[i].buffer = va_arg(*args, u_int);
277 bind[i].buffer_length = sizeof(u_int);
278 bind[i].is_unsigned = TRUE;
279 break;
280 }
281 case DB_TEXT:
282 {
283 bind[i].buffer_type = MYSQL_TYPE_STRING;;
284 bind[i].buffer = va_arg(*args, char*);
285 if (bind[i].buffer)
286 {
287 bind[i].buffer_length = strlen(bind[i].buffer);
288 }
289 break;
290 }
291 case DB_BLOB:
292 {
293 chunk_t chunk = va_arg(*args, chunk_t);
294 bind[i].buffer_type = MYSQL_TYPE_BLOB;
295 bind[i].buffer = chunk.ptr;
296 bind[i].buffer_length = chunk.len;
297 break;
298 }
299 case DB_DOUBLE:
300 {
301 bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
302 bind[i].buffer = (char*)alloca(sizeof(double));
303 *(double*)bind[i].buffer = va_arg(*args, double);
304 bind[i].buffer_length = sizeof(double);
305 break;
306 }
307 case DB_NULL:
308 {
309 bind[i].buffer_type = MYSQL_TYPE_NULL;
310 break;
311 }
312 default:
313 DBG1(DBG_LIB, "invalid data type supplied");
314 mysql_stmt_close(stmt);
315 return NULL;
316 }
317 }
318 if (mysql_stmt_bind_param(stmt, bind))
319 {
320 DBG1(DBG_LIB, "binding MySQL param failed: %s",
321 mysql_stmt_error(stmt));
322 mysql_stmt_close(stmt);
323 return NULL;
324 }
325 }
326 if (mysql_stmt_execute(stmt))
327 {
328 DBG1(DBG_LIB, "executing MySQL statement failed: %s",
329 mysql_stmt_error(stmt));
330 mysql_stmt_close(stmt);
331 return NULL;
332 }
333 return stmt;
334 }
335
336 typedef struct {
337 /** implements enumerator_t */
338 enumerator_t public;
339 /** mysql database */
340 private_mysql_database_t *db;
341 /** associated MySQL statement */
342 MYSQL_STMT *stmt;
343 /** result bindings */
344 MYSQL_BIND *bind;
345 /** pooled connection handle */
346 conn_t *conn;
347 /** value for INT, UINT, double */
348 union {
349 void *p_void;;
350 int *p_int;
351 u_int *p_uint;
352 double *p_double;
353 } val;
354 /* length for TEXT and BLOB */
355 unsigned long *length;
356 } mysql_enumerator_t;
357
358 /**
359 * create a mysql enumerator
360 */
361 static void mysql_enumerator_destroy(mysql_enumerator_t *this)
362 {
363 int columns, i;
364
365 columns = mysql_stmt_field_count(this->stmt);
366
367 for (i = 0; i < columns; i++)
368 {
369 switch (this->bind[i].buffer_type)
370 {
371 case MYSQL_TYPE_STRING:
372 case MYSQL_TYPE_BLOB:
373 {
374 free(this->bind[i].buffer);
375 break;
376 }
377 default:
378 break;
379 }
380 }
381 mysql_stmt_close(this->stmt);
382 conn_release(this->db, this->conn);
383 free(this->bind);
384 free(this->val.p_void);
385 free(this->length);
386 free(this);
387 }
388
389 /**
390 * Implementation of database.query().enumerate
391 */
392 static bool mysql_enumerator_enumerate(mysql_enumerator_t *this, ...)
393 {
394 int i, columns;
395 va_list args;
396
397 columns = mysql_stmt_field_count(this->stmt);
398
399 /* free/reset data set of previous call */
400 for (i = 0; i < columns; i++)
401 {
402 switch (this->bind[i].buffer_type)
403 {
404 case MYSQL_TYPE_STRING:
405 case MYSQL_TYPE_BLOB:
406 {
407 free(this->bind[i].buffer);
408 this->bind[i].buffer = NULL;
409 this->bind[i].buffer_length = 0;
410 this->bind[i].length = &this->length[i];
411 this->length[i] = 0;
412 break;
413 }
414 default:
415 break;
416 }
417 }
418
419 switch (mysql_stmt_fetch(this->stmt))
420 {
421 case 0:
422 case MYSQL_DATA_TRUNCATED:
423 break;
424 case MYSQL_NO_DATA:
425 return FALSE;
426 default:
427 DBG1(DBG_LIB, "fetching MySQL row failed: %s",
428 mysql_stmt_error(this->stmt));
429 return FALSE;
430 }
431
432 va_start(args, this);
433 for (i = 0; i < columns; i++)
434 {
435 switch (this->bind[i].buffer_type)
436 {
437 case MYSQL_TYPE_LONG:
438 {
439 if (this->bind[i].is_unsigned)
440 {
441 u_int *value = va_arg(args, u_int*);
442 *value = this->val.p_uint[i];
443 }
444 else
445 {
446 int *value = va_arg(args, int*);
447 *value = this->val.p_int[i];
448 }
449 break;
450 }
451 case MYSQL_TYPE_STRING:
452 {
453 char **value = va_arg(args, char**);
454 this->bind[i].buffer = malloc(this->length[i]+1);
455 this->bind[i].buffer_length = this->length[i];
456 *value = this->bind[i].buffer;
457 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
458 ((char*)this->bind[i].buffer)[this->length[i]] = '\0';
459 break;
460 }
461 case MYSQL_TYPE_BLOB:
462 {
463 chunk_t *value = va_arg(args, chunk_t*);
464 this->bind[i].buffer = malloc(this->length[i]);
465 this->bind[i].buffer_length = this->length[i];
466 value->ptr = this->bind[i].buffer;
467 value->len = this->length[i];
468 mysql_stmt_fetch_column(this->stmt, &this->bind[i], i, 0);
469 break;
470 }
471 case MYSQL_TYPE_DOUBLE:
472 {
473 double *value = va_arg(args, double*);
474 *value = this->val.p_double[i];
475 break;
476 }
477 default:
478 break;
479 }
480 }
481 va_end(args);
482 return TRUE;
483 }
484
485 METHOD(database_t, query, enumerator_t*,
486 private_mysql_database_t *this, char *sql, ...)
487 {
488 MYSQL_STMT *stmt;
489 va_list args;
490 mysql_enumerator_t *enumerator = NULL;
491 conn_t *conn;
492
493 conn = conn_get(this);
494 if (!conn)
495 {
496 return NULL;
497 }
498
499 va_start(args, sql);
500 stmt = run(conn->mysql, sql, &args);
501 if (stmt)
502 {
503 int columns, i;
504
505 INIT(enumerator,
506 .public = {
507 .enumerate = (void*)mysql_enumerator_enumerate,
508 .destroy = (void*)mysql_enumerator_destroy,
509
510 },
511 .db = this,
512 .stmt = stmt,
513 .conn = conn,
514 );
515 columns = mysql_stmt_field_count(stmt);
516 enumerator->bind = calloc(columns, sizeof(MYSQL_BIND));
517 enumerator->length = calloc(columns, sizeof(unsigned long));
518 enumerator->val.p_void = calloc(columns, sizeof(enumerator->val));
519 for (i = 0; i < columns; i++)
520 {
521 switch (va_arg(args, db_type_t))
522 {
523 case DB_INT:
524 {
525 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
526 enumerator->bind[i].buffer = (char*)&enumerator->val.p_int[i];
527 break;
528 }
529 case DB_UINT:
530 {
531 enumerator->bind[i].buffer_type = MYSQL_TYPE_LONG;
532 enumerator->bind[i].buffer = (char*)&enumerator->val.p_uint[i];
533 enumerator->bind[i].is_unsigned = TRUE;
534 break;
535 }
536 case DB_TEXT:
537 {
538 enumerator->bind[i].buffer_type = MYSQL_TYPE_STRING;
539 enumerator->bind[i].length = &enumerator->length[i];
540 break;
541 }
542 case DB_BLOB:
543 {
544 enumerator->bind[i].buffer_type = MYSQL_TYPE_BLOB;
545 enumerator->bind[i].length = &enumerator->length[i];
546 break;
547 }
548 case DB_DOUBLE:
549 {
550 enumerator->bind[i].buffer_type = MYSQL_TYPE_DOUBLE;
551 enumerator->bind[i].buffer = (char*)&enumerator->val.p_double[i];
552 break;
553 }
554 default:
555 DBG1(DBG_LIB, "invalid result data type supplied");
556 mysql_enumerator_destroy(enumerator);
557 va_end(args);
558 return NULL;
559 }
560 }
561 if (mysql_stmt_bind_result(stmt, enumerator->bind))
562 {
563 DBG1(DBG_LIB, "binding MySQL result failed: %s",
564 mysql_stmt_error(stmt));
565 mysql_enumerator_destroy(enumerator);
566 enumerator = NULL;
567 }
568 }
569 else
570 {
571 conn_release(this, conn);
572 }
573 va_end(args);
574 return (enumerator_t*)enumerator;
575 }
576
577 METHOD(database_t, execute, int,
578 private_mysql_database_t *this, int *rowid, char *sql, ...)
579 {
580 MYSQL_STMT *stmt;
581 va_list args;
582 conn_t *conn;
583 int affected = -1;
584
585 conn = conn_get(this);
586 if (!conn)
587 {
588 return -1;
589 }
590 va_start(args, sql);
591 stmt = run(conn->mysql, sql, &args);
592 if (stmt)
593 {
594 if (rowid)
595 {
596 *rowid = mysql_stmt_insert_id(stmt);
597 }
598 affected = mysql_stmt_affected_rows(stmt);
599 mysql_stmt_close(stmt);
600 }
601 va_end(args);
602 conn_release(this, conn);
603 return affected;
604 }
605
606 METHOD(database_t, transaction, bool,
607 private_mysql_database_t *this)
608 {
609 return FALSE;
610 }
611
612 METHOD(database_t, commit, bool,
613 private_mysql_database_t *this)
614 {
615 return FALSE;
616 }
617
618 METHOD(database_t, rollback, bool,
619 private_mysql_database_t *this)
620 {
621 return FALSE;
622 }
623
624 METHOD(database_t, get_driver,db_driver_t,
625 private_mysql_database_t *this)
626 {
627 return DB_MYSQL;
628 }
629
630 METHOD(database_t, destroy, void,
631 private_mysql_database_t *this)
632 {
633 this->pool->destroy_function(this->pool, (void*)conn_destroy);
634 this->mutex->destroy(this->mutex);
635 free(this->host);
636 free(this->username);
637 free(this->password);
638 free(this->database);
639 free(this);
640 }
641
642 static bool parse_uri(private_mysql_database_t *this, char *uri)
643 {
644 char *username, *password, *host, *port = "0", *database, *pos;
645
646 /**
647 * parse mysql://username:pass@host:port/database uri
648 */
649 username = strdupa(uri + 8);
650 pos = strchr(username, ':');
651 if (pos)
652 {
653 *pos = '\0';
654 password = pos + 1;
655 pos = strrchr(password, '@');
656 if (pos)
657 {
658 *pos = '\0';
659 host = pos + 1;
660 pos = strrchr(host, ':');
661 if (pos)
662 {
663 *pos = '\0';
664 port = pos + 1;
665 pos = strchr(port, '/');
666 }
667 else
668 {
669 pos = strchr(host, '/');
670 }
671 if (pos)
672 {
673 *pos = '\0';
674 database = pos + 1;
675
676 this->host = strdup(host);
677 this->username = strdup(username);
678 this->password = strdup(password);
679 this->database = strdup(database);
680 this->port = atoi(port);
681 return TRUE;
682 }
683 }
684 }
685 DBG1(DBG_LIB, "parsing MySQL database uri '%s' failed", uri);
686 return FALSE;
687 }
688
689
690 /*
691 * see header file
692 */
693 mysql_database_t *mysql_database_create(char *uri)
694 {
695 conn_t *conn;
696 private_mysql_database_t *this;
697
698 if (!strpfx(uri, "mysql://"))
699 {
700 return NULL;
701 }
702
703 INIT(this,
704 .public = {
705 .db = {
706 .query = _query,
707 .execute = _execute,
708 .transaction = _transaction,
709 .commit = _commit,
710 .rollback = _rollback,
711 .get_driver = _get_driver,
712 .destroy = _destroy,
713 },
714 },
715 );
716
717 if (!parse_uri(this, uri))
718 {
719 free(this);
720 return NULL;
721 }
722 this->mutex = mutex_create(MUTEX_TYPE_DEFAULT);
723 this->pool = linked_list_create();
724
725 /* check connectivity */
726 conn = conn_get(this);
727 if (!conn)
728 {
729 destroy(this);
730 return NULL;
731 }
732 conn_release(this, conn);
733 return &this->public;
734 }