condvar->wait() can handle recursive mutex
authorMartin Willi <martin@strongswan.org>
Thu, 16 Oct 2008 11:29:42 +0000 (11:29 -0000)
committerMartin Willi <martin@strongswan.org>
Thu, 16 Oct 2008 11:29:42 +0000 (11:29 -0000)
src/libstrongswan/utils/mutex.c

index f16c5b2..1c3185d 100644 (file)
 
 #include <pthread.h>
 #include <sys/time.h>
+#include <stdint.h>
 #include <time.h>
 #include <errno.h>
 
 
 typedef struct private_mutex_t private_mutex_t;
-typedef struct private_n_mutex_t private_n_mutex_t;
 typedef struct private_r_mutex_t private_r_mutex_t;
 typedef struct private_condvar_t private_condvar_t;
 
@@ -45,6 +45,11 @@ struct private_mutex_t {
         * wrapped pthread mutex
         */
        pthread_mutex_t mutex;
+       
+       /**
+        * is this a recursiv emutex, implementing private_r_mutex_t?
+        */
+       bool recursive;
 };
 
 /**
@@ -53,7 +58,7 @@ struct private_mutex_t {
 struct private_r_mutex_t {
 
        /**
-        * public functions
+        * Extends private_mutex_t
         */
        private_mutex_t generic;
        
@@ -63,9 +68,9 @@ struct private_r_mutex_t {
        pthread_t thread;
        
        /**
-        * times we have locked the lock
+        * times we have locked the lock, stored per thread
         */
-       int times;
+       pthread_key_t times;
 };
 
 /**
@@ -115,12 +120,19 @@ static void lock_r(private_r_mutex_t *this)
 
        if (this->thread == self)
        {
-               this->times++;
-               return;
+               uintptr_t times;
+               
+               /* times++ */
+               times = (uintptr_t)pthread_getspecific(this->times);
+               pthread_setspecific(this->times, (void*)times + 1);
+       }
+       else
+       {
+               lock(&this->generic);
+               this->thread = self;
+               /* times = 1 */
+               pthread_setspecific(this->times, (void*)1);
        }
-       lock(&this->generic);
-       this->thread = self;
-       this->times = 1;
 }
 
 /**
@@ -128,7 +140,13 @@ static void lock_r(private_r_mutex_t *this)
  */
 static void unlock_r(private_r_mutex_t *this)
 {
-       if (--this->times == 0)
+       uintptr_t times;
+
+       /* times-- */
+       times = (uintptr_t)pthread_getspecific(this->times);
+       pthread_setspecific(this->times, (void*)--times);
+       
+       if (times == 0)
        {
                this->thread = 0;
                unlock(&this->generic);
@@ -144,6 +162,16 @@ static void mutex_destroy(private_mutex_t *this)
        free(this);
 }
 
+/**
+ * Implementation of mutex_t.destroy for recursive mutex'
+ */
+static void mutex_destroy_r(private_r_mutex_t *this)
+{
+       pthread_mutex_destroy(&this->generic.mutex);
+       pthread_key_delete(this->times);
+       free(this);
+}
+
 /*
  * see header file
  */
@@ -154,15 +182,16 @@ mutex_t *mutex_create(mutex_type_t type)
                case MUTEX_RECURSIVE:
                {
                        private_r_mutex_t *this = malloc_thing(private_r_mutex_t);
-       
+                       
                        this->generic.public.lock = (void(*)(mutex_t*))lock_r;
                        this->generic.public.unlock = (void(*)(mutex_t*))unlock_r;
-                       this->generic.public.destroy = (void(*)(mutex_t*))mutex_destroy;        
-       
+                       this->generic.public.destroy = (void(*)(mutex_t*))mutex_destroy_r;      
+                       
                        pthread_mutex_init(&this->generic.mutex, NULL);
+                       pthread_key_create(&this->times, NULL);
+                       this->generic.recursive = TRUE;
                        this->thread = 0;
-                       this->times = 0;
-       
+                       
                        return &this->generic.public;
                }
                case MUTEX_DEFAULT:
@@ -173,9 +202,10 @@ mutex_t *mutex_create(mutex_type_t type)
                        this->public.lock = (void(*)(mutex_t*))lock;
                        this->public.unlock = (void(*)(mutex_t*))unlock;
                        this->public.destroy = (void(*)(mutex_t*))mutex_destroy;
-               
+                       
                        pthread_mutex_init(&this->mutex, NULL);
-               
+                       this->recursive = FALSE;
+                       
                        return &this->public;
                }
        }
@@ -186,7 +216,19 @@ mutex_t *mutex_create(mutex_type_t type)
  */
 static void wait(private_condvar_t *this, private_mutex_t *mutex)
 {
-       pthread_cond_wait(&this->condvar, &mutex->mutex);
+       if (mutex->recursive)
+       {
+               private_r_mutex_t* recursive = (private_r_mutex_t*)mutex;
+               
+               /* mutex owner gets cleared during condvar wait */
+               recursive->thread = 0;
+               pthread_cond_wait(&this->condvar, &mutex->mutex);
+               recursive->thread = pthread_self();
+       }
+       else
+       {
+               pthread_cond_wait(&this->condvar, &mutex->mutex);
+       }
 }
 
 /**
@@ -198,6 +240,7 @@ static bool timed_wait(private_condvar_t *this, private_mutex_t *mutex,
        struct timespec ts;
        struct timeval tv;
        u_int s, ms;
+       bool timed_out;
        
        gettimeofday(&tv, NULL);
        
@@ -211,8 +254,21 @@ static bool timed_wait(private_condvar_t *this, private_mutex_t *mutex,
                ts.tv_nsec -= 1000000000;
                ts.tv_sec++;
        }
-       return (pthread_cond_timedwait(&this->condvar, &mutex->mutex,
-                                                                  &ts) == ETIMEDOUT);
+       if (mutex->recursive)
+       {
+               private_r_mutex_t* recursive = (private_r_mutex_t*)mutex;
+               
+               recursive->thread = 0;
+               timed_out = pthread_cond_timedwait(&this->condvar, &mutex->mutex,
+                                                                                  &ts) == ETIMEDOUT;
+               recursive->thread = pthread_self();
+       }
+       else
+       {
+               timed_out = pthread_cond_timedwait(&this->condvar, &mutex->mutex,
+                                                                                  &ts) == ETIMEDOUT;
+       }
+       return timed_out;
 }
 
 /**