Add locking to IMC/IMV managers to add/remove IMC/IMVs on the fly
authorMartin Willi <martin@revosec.ch>
Tue, 20 Nov 2012 13:34:00 +0000 (14:34 +0100)
committerMartin Willi <martin@revosec.ch>
Fri, 30 Nov 2012 14:49:23 +0000 (15:49 +0100)
src/libcharon/plugins/tnc_imc/tnc_imc_manager.c
src/libcharon/plugins/tnc_imv/tnc_imv_manager.c

index 544270a..d2fce6f 100644 (file)
 
 #include <tncifimc.h>
 
-#include <collections/linked_list.h>
-#include <utils/debug.h>
 #include <daemon.h>
+#include <utils/debug.h>
+#include <threading/rwlock.h>
+#include <collections/linked_list.h>
 
 typedef struct private_tnc_imc_manager_t private_tnc_imc_manager_t;
 
@@ -41,6 +42,11 @@ struct private_tnc_imc_manager_t {
        linked_list_t *imcs;
 
        /**
+        * Lock to access IMC list
+        */
+       rwlock_t *lock;
+
+       /**
         * Next IMC ID to be assigned
         */
        TNC_IMCID next_imc_id;
@@ -58,8 +64,10 @@ METHOD(imc_manager_t, add, bool,
                DBG1(DBG_TNC, "IMC \"%s\" failed to initialize", imc->get_name(imc));
                return FALSE;
        }
+       this->lock->write_lock(this->lock);
        this->imcs->insert_last(this->imcs, imc);
        this->next_imc_id++;
+       this->lock->unlock(this->lock);
 
        if (imc->provide_bind_function(imc->get_id(imc),
                                                                   TNC_TNCC_BindFunction) != TNC_RESULT_SUCCESS)
@@ -70,7 +78,9 @@ METHOD(imc_manager_t, add, bool,
                }
                DBG1(DBG_TNC, "IMC \"%s\" failed to obtain bind function",
                         imc->get_name(imc));
+               this->lock->write_lock(this->lock);
                this->imcs->remove_last(this->imcs, (void**)&imc);
+               this->lock->unlock(this->lock);
                return FALSE;
        }
        return TRUE;
@@ -82,6 +92,7 @@ METHOD(imc_manager_t, remove_, imc_t*,
        enumerator_t *enumerator;
        imc_t *imc, *removed_imc = NULL;
 
+       this->lock->write_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -93,6 +104,7 @@ METHOD(imc_manager_t, remove_, imc_t*,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return removed_imc;
 }
@@ -154,6 +166,7 @@ METHOD(imc_manager_t, is_registered, bool,
        imc_t *imc;
        bool found = FALSE;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -164,6 +177,7 @@ METHOD(imc_manager_t, is_registered, bool,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return found;
 }
@@ -175,6 +189,7 @@ METHOD(imc_manager_t, reserve_id, bool,
        imc_t *imc;
        bool found = FALSE;
 
+       this->lock->write_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -189,6 +204,7 @@ METHOD(imc_manager_t, reserve_id, bool,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return found;
 }
@@ -207,6 +223,7 @@ METHOD(imc_manager_t, notify_connection_change, void,
        enumerator_t *enumerator;
        imc_t *imc;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -216,6 +233,7 @@ METHOD(imc_manager_t, notify_connection_change, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
 METHOD(imc_manager_t, begin_handshake, void,
@@ -224,12 +242,14 @@ METHOD(imc_manager_t, begin_handshake, void,
        enumerator_t *enumerator;
        imc_t *imc;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
                imc->begin_handshake(imc->get_id(imc), id);
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
 METHOD(imc_manager_t, set_message_types, TNC_Result,
@@ -241,6 +261,7 @@ METHOD(imc_manager_t, set_message_types, TNC_Result,
        imc_t *imc;
        TNC_Result result = TNC_RESULT_FATAL;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -252,6 +273,7 @@ METHOD(imc_manager_t, set_message_types, TNC_Result,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
        return result;
 }
 
@@ -265,6 +287,7 @@ METHOD(imc_manager_t, set_message_types_long, TNC_Result,
        imc_t *imc;
        TNC_Result result = TNC_RESULT_FATAL;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -277,6 +300,7 @@ METHOD(imc_manager_t, set_message_types_long, TNC_Result,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
        return result;
 }
 
@@ -296,11 +320,12 @@ METHOD(imc_manager_t, receive_message, void,
        enumerator_t *enumerator;
        imc_t *imc;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
                if (imc->type_supported(imc, msg_vid, msg_subtype) &&
-                  (!excl || (excl && imc->has_id(imc, dst_imc_id)) ))
+                       (!excl || (excl && imc->has_id(imc, dst_imc_id))))
                {
                        if (imc->receive_message_long && src_imv_id)
                        {
@@ -322,6 +347,8 @@ METHOD(imc_manager_t, receive_message, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
+
        if (!type_supported)
        {
                DBG2(DBG_TNC, "message type 0x%06x/0x%08x not supported by any IMC",
@@ -335,6 +362,7 @@ METHOD(imc_manager_t, batch_ending, void,
        enumerator_t *enumerator;
        imc_t *imc;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imcs->create_enumerator(this->imcs);
        while (enumerator->enumerate(enumerator, &imc))
        {
@@ -344,6 +372,7 @@ METHOD(imc_manager_t, batch_ending, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
 METHOD(imc_manager_t, destroy, void,
@@ -362,6 +391,7 @@ METHOD(imc_manager_t, destroy, void,
                imc->destroy(imc);
        }
        this->imcs->destroy(this->imcs);
+       this->lock->destroy(this->lock);
        free(this);
 }
 
@@ -390,6 +420,7 @@ imc_manager_t* tnc_imc_manager_create(void)
                        .destroy = _destroy,
                },
                .imcs = linked_list_create(),
+               .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
                .next_imc_id = 1,
        );
 
index 5af6a72..308285d 100644 (file)
@@ -29,9 +29,9 @@
 #include <fcntl.h>
 
 #include <daemon.h>
-#include <utils/lexparser.h>
 #include <utils/debug.h>
-#include <threading/mutex.h>
+#include <threading/rwlock.h>
+#include <collections/linked_list.h>
 
 typedef struct private_tnc_imv_manager_t private_tnc_imv_manager_t;
 
@@ -51,6 +51,11 @@ struct private_tnc_imv_manager_t {
        linked_list_t *imvs;
 
        /**
+        * Lock for IMV list
+        */
+       rwlock_t *lock;
+
+       /**
         * Next IMV ID to be assigned
         */
        TNC_IMVID next_imv_id;
@@ -73,8 +78,10 @@ METHOD(imv_manager_t, add, bool,
                DBG1(DBG_TNC, "IMV \"%s\" failed to initialize", imv->get_name(imv));
                return FALSE;
        }
+       this->lock->write_lock(this->lock);
        this->imvs->insert_last(this->imvs, imv);
        this->next_imv_id++;
+       this->lock->unlock(this->lock);
 
        if (imv->provide_bind_function(imv->get_id(imv),
                                                                   TNC_TNCS_BindFunction) != TNC_RESULT_SUCCESS)
@@ -85,7 +92,9 @@ METHOD(imv_manager_t, add, bool,
                }
                DBG1(DBG_TNC, "IMV \"%s\" failed to obtain bind function",
                         imv->get_name(imv));
+               this->lock->write_lock(this->lock);
                this->imvs->remove_last(this->imvs, (void**)&imv);
+               this->lock->unlock(this->lock);
                return FALSE;
        }
        return TRUE;
@@ -97,6 +106,7 @@ METHOD(imv_manager_t, remove_, imv_t*,
        enumerator_t *enumerator;
        imv_t *imv, *removed_imv = NULL;
 
+       this->lock->write_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -108,6 +118,7 @@ METHOD(imv_manager_t, remove_, imv_t*,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return removed_imv;
 }
@@ -169,6 +180,7 @@ METHOD(imv_manager_t, is_registered, bool,
        imv_t *imv;
        bool found = FALSE;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -179,6 +191,7 @@ METHOD(imv_manager_t, is_registered, bool,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return found;
 }
@@ -190,6 +203,7 @@ METHOD(imv_manager_t, reserve_id, bool,
        imv_t *imv;
        bool found = FALSE;
 
+       this->lock->write_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -204,6 +218,7 @@ METHOD(imv_manager_t, reserve_id, bool,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 
        return found;
 }
@@ -283,6 +298,7 @@ METHOD(imv_manager_t, notify_connection_change, void,
        enumerator_t *enumerator;
        imv_t *imv;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -292,6 +308,7 @@ METHOD(imv_manager_t, notify_connection_change, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
 METHOD(imv_manager_t, set_message_types, TNC_Result,
@@ -303,6 +320,7 @@ METHOD(imv_manager_t, set_message_types, TNC_Result,
        imv_t *imv;
        TNC_Result result = TNC_RESULT_FATAL;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -314,6 +332,7 @@ METHOD(imv_manager_t, set_message_types, TNC_Result,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
        return result;
 }
 
@@ -327,6 +346,7 @@ METHOD(imv_manager_t, set_message_types_long, TNC_Result,
        imv_t *imv;
        TNC_Result result = TNC_RESULT_FATAL;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -339,6 +359,7 @@ METHOD(imv_manager_t, set_message_types_long, TNC_Result,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
        return result;
 }
 
@@ -348,12 +369,14 @@ METHOD(imv_manager_t, solicit_recommendation, void,
        enumerator_t *enumerator;
        imv_t *imv;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
                imv->solicit_recommendation(imv->get_id(imv), id);
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
 METHOD(imv_manager_t, receive_message, void,
@@ -374,11 +397,12 @@ METHOD(imv_manager_t, receive_message, void,
 
        msg_type = (msg_vid << 8) | msg_subtype;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
                if (imv->type_supported(imv, msg_vid, msg_subtype) &&
-                  (!excl || (excl && imv->has_id(imv, dst_imv_id)) ))
+                       (!excl || (excl && imv->has_id(imv, dst_imv_id))))
                {
                        if (imv->receive_message_long && src_imc_id)
                        {
@@ -400,6 +424,8 @@ METHOD(imv_manager_t, receive_message, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
+
        if (!type_supported)
        {
                DBG2(DBG_TNC, "message type 0x%06x/0x%08x not supported by any IMV",
@@ -413,6 +439,7 @@ METHOD(imv_manager_t, batch_ending, void,
        enumerator_t *enumerator;
        imv_t *imv;
 
+       this->lock->read_lock(this->lock);
        enumerator = this->imvs->create_enumerator(this->imvs);
        while (enumerator->enumerate(enumerator, &imv))
        {
@@ -422,9 +449,9 @@ METHOD(imv_manager_t, batch_ending, void,
                }
        }
        enumerator->destroy(enumerator);
+       this->lock->unlock(this->lock);
 }
 
-
 METHOD(imv_manager_t, destroy, void,
        private_tnc_imv_manager_t *this)
 {
@@ -441,6 +468,7 @@ METHOD(imv_manager_t, destroy, void,
                imv->destroy(imv);
        }
        this->imvs->destroy(this->imvs);
+       this->lock->destroy(this->lock);
        free(this);
 }
 
@@ -472,6 +500,7 @@ imv_manager_t* tnc_imv_manager_create(void)
                        .destroy = _destroy,
                },
                .imvs = linked_list_create(),
+               .lock = rwlock_create(RWLOCK_TYPE_DEFAULT),
                .next_imv_id = 1,
        );