added threads to support multiple simultaneous stroke requests
authorMartin Willi <martin@strongswan.org>
Fri, 2 Feb 2007 07:30:19 +0000 (07:30 -0000)
committerMartin Willi <martin@strongswan.org>
Fri, 2 Feb 2007 07:30:19 +0000 (07:30 -0000)
src/charon/threads/stroke_interface.c

index 887ce45..9509473 100755 (executable)
@@ -150,7 +150,7 @@ static x509_t* load_end_certificate(const char *filename, identification_t **idp
 /**
  * Add a connection to the configuration list
  */
-static void stroke_add_conn(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_add_conn(stroke_msg_t *msg, FILE *out)
 {
        connection_t *connection;
        policy_t *policy;
@@ -460,7 +460,7 @@ destroy_hosts:
 /**
  * Delete a connection from the list
  */
-static void stroke_del_conn(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_del_conn(stroke_msg_t *msg, FILE *out)
 {
        status_t status;
        
@@ -472,18 +472,18 @@ static void stroke_del_conn(private_stroke_t *this, stroke_msg_t *msg)
        charon->policies->delete_policy(charon->policies, msg->del_conn.name);
        if (status == SUCCESS)
        {
-               fprintf(this->out, "deleted connection '%s'\n", msg->del_conn.name);
+               fprintf(out, "deleted connection '%s'\n", msg->del_conn.name);
        }
        else
        {
-               fprintf(this->out, "no connection named '%s'\n", msg->del_conn.name);
+               fprintf(out, "no connection named '%s'\n", msg->del_conn.name);
        }
 }
 
 /**
  * initiate a connection by name
  */
-static void stroke_initiate(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_initiate(stroke_msg_t *msg, FILE *out)
 {
        initiate_job_t *job;
        connection_t *connection;
@@ -500,7 +500,7 @@ static void stroke_initiate(private_stroke_t *this, stroke_msg_t *msg)
        {
                if (msg->output_verbosity >= 0)
                {
-                       fprintf(this->out, "no connection named '%s'\n", msg->initiate.name);
+                       fprintf(out, "no connection named '%s'\n", msg->initiate.name);
                }
                return;
        }
@@ -516,7 +516,7 @@ static void stroke_initiate(private_stroke_t *this, stroke_msg_t *msg)
        {
                if (msg->output_verbosity >= 0)
                {
-                       fprintf(this->out, "no policy named '%s'\n", msg->initiate.name);
+                       fprintf(out, "no policy named '%s'\n", msg->initiate.name);
                }
                connection->destroy(connection);
                return;
@@ -547,9 +547,9 @@ static void stroke_initiate(private_stroke_t *this, stroke_msg_t *msg)
                if ((init_ike_sa == NULL || ike_sa == init_ike_sa) &&
                        level <= msg->output_verbosity)
                {
-                       if (vfprintf(this->out, format, args) < 0 ||
-                               fprintf(this->out, "\n") < 0 ||
-                               fflush(this->out))
+                       if (vfprintf(out, format, args) < 0 ||
+                               fprintf(out, "\n") < 0 ||
+                               fflush(out))
                        {
                                charon->bus->set_listen_state(charon->bus, FALSE);
                                break;
@@ -583,7 +583,7 @@ static void stroke_initiate(private_stroke_t *this, stroke_msg_t *msg)
 /**
  * route/unroute a policy (install SPD entries)
  */
-static void stroke_route(private_stroke_t *this, stroke_msg_t *msg, bool route)
+static void stroke_route(stroke_msg_t *msg, FILE *out, bool route)
 {
        route_job_t *job;
        connection_t *connection;
@@ -599,7 +599,7 @@ static void stroke_route(private_stroke_t *this, stroke_msg_t *msg, bool route)
                                                                                                                         msg->route.name);
        if (connection == NULL)
        {
-               fprintf(this->out, "no connection named '%s'\n", msg->route.name);
+               fprintf(out, "no connection named '%s'\n", msg->route.name);
                return;
        }
        if (!connection->is_ikev2(connection))
@@ -612,11 +612,11 @@ static void stroke_route(private_stroke_t *this, stroke_msg_t *msg, bool route)
                                                                                                  msg->route.name);
        if (policy == NULL)
        {
-               fprintf(this->out, "no policy named '%s'\n", msg->route.name);
+               fprintf(out, "no policy named '%s'\n", msg->route.name);
                connection->destroy(connection);
                return;
        }
-       fprintf(this->out, "%s policy '%s'\n",
+       fprintf(out, "%s policy '%s'\n",
                        route ? "routing" : "unrouting", msg->route.name);
        job = route_job_create(connection, policy, route);
        charon->job_queue->add(charon->job_queue, (job_t*)job);
@@ -625,7 +625,7 @@ static void stroke_route(private_stroke_t *this, stroke_msg_t *msg, bool route)
 /**
  * terminate a connection by name
  */
-static void stroke_terminate(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_terminate(stroke_msg_t *msg, FILE *out)
 {
        pop_string(msg, &(msg->terminate.name));
        DBG1(DBG_CFG, "received stroke: terminate '%s'", msg->terminate.name);
@@ -636,7 +636,7 @@ static void stroke_terminate(private_stroke_t *this, stroke_msg_t *msg)
 /**
  * show status of daemon
  */
-static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_statusall(stroke_msg_t *msg, FILE *out)
 {
        iterator_t *iterator;
        linked_list_t *list;
@@ -646,22 +646,22 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
        ike_sa_t *ike_sa;
        char *name = NULL;
 
-       leak_detective_status(this->out);
+       leak_detective_status(out);
        
-       fprintf(this->out, "Performance:\n");
-       fprintf(this->out, "  worker threads: %d idle of %d,",
+       fprintf(out, "Performance:\n");
+       fprintf(out, "  worker threads: %d idle of %d,",
                        charon->thread_pool->get_idle_threads(charon->thread_pool),
                        charon->thread_pool->get_pool_size(charon->thread_pool));
-       fprintf(this->out, " job queue load: %d,",
+       fprintf(out, " job queue load: %d,",
                        charon->job_queue->get_count(charon->job_queue));
-       fprintf(this->out, " scheduled events: %d\n",
+       fprintf(out, " scheduled events: %d\n",
                        charon->event_queue->get_count(charon->event_queue));
        list = charon->socket->create_local_address_list(charon->socket);
 
-       fprintf(this->out, "Listening on %d IP addresses:\n", list->get_count(list));
+       fprintf(out, "Listening on %d IP addresses:\n", list->get_count(list));
        while (list->remove_first(list, (void**)&host) == SUCCESS)
        {
-               fprintf(this->out, "  %H\n", host);
+               fprintf(out, "  %H\n", host);
                host->destroy(host);
        }
        list->destroy(list);
@@ -675,14 +675,14 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
        iterator = charon->connections->create_iterator(charon->connections);
        if (iterator->get_count(iterator) > 0)
        {
-               fprintf(this->out, "Connections:\n");
+               fprintf(out, "Connections:\n");
        }
        while (iterator->iterate(iterator, (void**)&connection))
        {
                if (connection->is_ikev2(connection)
                && (name == NULL || streq(name, connection->get_name(connection))))
                {
-                       fprintf(this->out, "%12s:  %H...%H\n",
+                       fprintf(out, "%12s:  %H...%H\n",
                                        connection->get_name(connection),
                                        connection->get_my_host(connection),
                                        connection->get_other_host(connection));
@@ -693,13 +693,13 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
        iterator = charon->policies->create_iterator(charon->policies);
        if (iterator->get_count(iterator) > 0)
        {
-               fprintf(this->out, "Policies:\n");
+               fprintf(out, "Policies:\n");
        }
        while (iterator->iterate(iterator, (void**)&policy))
        {
                if (name == NULL || streq(name, policy->get_name(policy)))
                {
-                       fprintf(this->out, "%12s:  '%D'...'%D'\n",
+                       fprintf(out, "%12s:  '%D'...'%D'\n",
                                        policy->get_name(policy),
                                        policy->get_my_id(policy),
                                        policy->get_other_id(policy));
@@ -710,7 +710,7 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
        iterator = charon->ike_sa_manager->create_iterator(charon->ike_sa_manager);
        if (iterator->get_count(iterator) > 0)
        {
-               fprintf(this->out, "Security Associations:\n");
+               fprintf(out, "Security Associations:\n");
        }
        while (iterator->iterate(iterator, (void**)&ike_sa))
        {
@@ -721,7 +721,7 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
                /* print IKE_SA */
                if (name == NULL || strncmp(name, ike_sa->get_name(ike_sa), strlen(name)) == 0)
                {
-                       fprintf(this->out, "%#K\n", ike_sa);
+                       fprintf(out, "%#K\n", ike_sa);
                        ike_sa_printed = TRUE;
                }
 
@@ -733,14 +733,14 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
                        /* print IKE_SA if its name differs from the CHILD_SA's name */
                        if (!ike_sa_printed && child_sa_match)
                        {
-                               fprintf(this->out, "%#K\n", ike_sa);
+                               fprintf(out, "%#K\n", ike_sa);
                                ike_sa_printed = TRUE;
                        }
 
                        /* print CHILD_SA */
                        if (child_sa_match)
                        {
-                               fprintf(this->out, "%#P\n", child_sa);
+                               fprintf(out, "%#P\n", child_sa);
                        }
                }
                children->destroy(children);
@@ -751,7 +751,7 @@ static void stroke_statusall(private_stroke_t *this, stroke_msg_t *msg)
 /**
  * show status of daemon
  */
-static void stroke_status(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_status(stroke_msg_t *msg, FILE *out)
 {
        iterator_t *iterator;
        ike_sa_t *ike_sa;
@@ -773,7 +773,7 @@ static void stroke_status(private_stroke_t *this, stroke_msg_t *msg)
                /* print IKE_SA */
                if (name == NULL || strncmp(name, ike_sa->get_name(ike_sa), strlen(name)) == 0)
                {
-                       fprintf(this->out, "%K\n", ike_sa);
+                       fprintf(out, "%K\n", ike_sa);
                        ike_sa_printed = TRUE;
                }
 
@@ -785,14 +785,14 @@ static void stroke_status(private_stroke_t *this, stroke_msg_t *msg)
                        /* print IKE_SA if its name differs from the CHILD_SA's name */
                        if (!ike_sa_printed && child_sa_match)
                        {
-                               fprintf(this->out, "%K\n", ike_sa);
+                               fprintf(out, "%K\n", ike_sa);
                                ike_sa_printed = TRUE;
                        }
 
                        /* print CHILD_SA */
                        if (child_sa_match)
                        {
-                               fprintf(this->out, "%P\n", child_sa);
+                               fprintf(out, "%P\n", child_sa);
                        }
                }
                children->destroy(children);
@@ -803,7 +803,7 @@ static void stroke_status(private_stroke_t *this, stroke_msg_t *msg)
 /**
  * list various information
  */
-static void stroke_list(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_list(stroke_msg_t *msg, FILE *out)
 {
        iterator_t *iterator;
        
@@ -814,19 +814,19 @@ static void stroke_list(private_stroke_t *this, stroke_msg_t *msg)
                iterator = charon->credentials->create_cert_iterator(charon->credentials);
                if (iterator->get_count(iterator))
                {
-                       fprintf(this->out, "\n");
-                       fprintf(this->out, "List of X.509 End Entity Certificates:\n");
-                       fprintf(this->out, "\n");
+                       fprintf(out, "\n");
+                       fprintf(out, "List of X.509 End Entity Certificates:\n");
+                       fprintf(out, "\n");
                }
                while (iterator->iterate(iterator, (void**)&cert))
                {
-                       fprintf(this->out, "%#Q", cert, msg->list.utc);
+                       fprintf(out, "%#Q", cert, msg->list.utc);
                        if (charon->credentials->has_rsa_private_key(
                                        charon->credentials, cert->get_public_key(cert)))
                        {
-                               fprintf(this->out, ", has private key");
+                               fprintf(out, ", has private key");
                        }
-                       fprintf(this->out, "\n");
+                       fprintf(out, "\n");
                        
                }
                iterator->destroy(iterator);
@@ -838,13 +838,13 @@ static void stroke_list(private_stroke_t *this, stroke_msg_t *msg)
                iterator = charon->credentials->create_cacert_iterator(charon->credentials);
                if (iterator->get_count(iterator))
                {
-                       fprintf(this->out, "\n");
-                       fprintf(this->out, "List of X.509 CA Certificates:\n");
-                       fprintf(this->out, "\n");
+                       fprintf(out, "\n");
+                       fprintf(out, "List of X.509 CA Certificates:\n");
+                       fprintf(out, "\n");
                }
                while (iterator->iterate(iterator, (void**)&cert))
                {
-                       fprintf(this->out, "%#Q\n", cert, msg->list.utc);
+                       fprintf(out, "%#Q\n", cert, msg->list.utc);
                }
                iterator->destroy(iterator);
        }
@@ -855,13 +855,13 @@ static void stroke_list(private_stroke_t *this, stroke_msg_t *msg)
                iterator = charon->credentials->create_crl_iterator(charon->credentials);
                if (iterator->get_count(iterator))
                {
-                       fprintf(this->out, "\n");
-                       fprintf(this->out, "List of X.509 CRLs:\n");
-                       fprintf(this->out, "\n");
+                       fprintf(out, "\n");
+                       fprintf(out, "List of X.509 CRLs:\n");
+                       fprintf(out, "\n");
                }
                while (iterator->iterate(iterator, (void**)&crl))
                {
-                       fprintf(this->out, "%#U\n", crl, msg->list.utc);
+                       fprintf(out, "%#U\n", crl, msg->list.utc);
                }
                iterator->destroy(iterator);
        }
@@ -870,7 +870,7 @@ static void stroke_list(private_stroke_t *this, stroke_msg_t *msg)
 /**
  * reread various information
  */
-static void stroke_reread(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_reread(stroke_msg_t *msg, FILE *out)
 {
        if (msg->reread.flags & REREAD_CACERTS)
        {
@@ -900,7 +900,7 @@ signal_t get_signal_from_logtype(char *type)
 /**
  * set the verbosity debug output
  */
-static void stroke_loglevel(private_stroke_t *this, stroke_msg_t *msg)
+static void stroke_loglevel(stroke_msg_t *msg, FILE *out)
 {
        signal_t signal;
        
@@ -911,7 +911,7 @@ static void stroke_loglevel(private_stroke_t *this, stroke_msg_t *msg)
        signal = get_signal_from_logtype(msg->loglevel.type);
        if (signal < 0)
        {
-               fprintf(this->out, "invalid type (%s)!\n", msg->loglevel.type);
+               fprintf(out, "invalid type (%s)!\n", msg->loglevel.type);
                return;
        }
        
@@ -920,17 +920,99 @@ static void stroke_loglevel(private_stroke_t *this, stroke_msg_t *msg)
 }
 
 /**
- * Implementation of private_stroke_t.stroke_receive.
+ * process a stroke request from the socket pointed by "fd"
  */
-static void stroke_receive(private_stroke_t *this)
+static void stroke_process(int *fd)
 {
        stroke_msg_t *msg;
        u_int16_t msg_length;
+       ssize_t bytes_read;
+       FILE *out;
+       int strokefd = *fd;
+       
+       /* peek the length */
+       bytes_read = recv(strokefd, &msg_length, sizeof(msg_length), MSG_PEEK);
+       if (bytes_read != sizeof(msg_length))
+       {
+               DBG1(DBG_CFG, "reading length of stroke message failed");
+               close(strokefd);
+               return;
+       }
+       
+       /* read message */
+       msg = malloc(msg_length);
+       bytes_read = recv(strokefd, msg, msg_length, 0);
+       if (bytes_read != msg_length)
+       {
+               DBG1(DBG_CFG, "reading stroke message failed: %m");
+               close(strokefd);
+               return;
+       }
+       
+       out = fdopen(dup(strokefd), "w");
+       if (out == NULL)
+       {
+               DBG1(DBG_CFG, "opening stroke output channel failed: %m");
+               close(strokefd);
+               free(msg);
+               return;
+       }
+       
+       DBG3(DBG_CFG, "stroke message %b", (void*)msg, msg_length);
+       
+       switch (msg->type)
+       {
+               case STR_INITIATE:
+                       stroke_initiate(msg, out);
+                       break;
+               case STR_ROUTE:
+                       stroke_route(msg, out, TRUE);
+                       break;
+               case STR_UNROUTE:
+                       stroke_route(msg, out, FALSE);
+                       break;
+               case STR_TERMINATE:
+                       stroke_terminate(msg, out);
+                       break;
+               case STR_STATUS:
+                       stroke_status(msg, out);
+                       break;
+               case STR_STATUS_ALL:
+                       stroke_statusall(msg, out);
+                       break;
+               case STR_ADD_CONN:
+                       stroke_add_conn(msg, out);
+                       break;
+               case STR_DEL_CONN:
+                       stroke_del_conn(msg, out);
+                       break;
+               case STR_LOGLEVEL:
+                       stroke_loglevel(msg, out);
+                       break;
+               case STR_LIST:
+                       stroke_list(msg, out);
+                       break;
+               case STR_REREAD:
+                       stroke_reread(msg, out);
+                       break;
+               default:
+                       DBG1(DBG_CFG, "received unknown stroke");
+       }
+       fclose(out);
+       close(strokefd);
+       free(msg);
+}
+
+/**
+ * Implementation of private_stroke_t.stroke_receive.
+ */
+static void stroke_receive(private_stroke_t *this)
+{
        struct sockaddr_un strokeaddr;
        int strokeaddrlen = sizeof(strokeaddr);
-       ssize_t bytes_read;
        int strokefd;
        int oldstate;
+       pthread_t thread;
        
        /* ignore sigpipe. writing over the pipe back to the console
         * only fails if SIGPIPE is ignored. */
@@ -939,7 +1021,7 @@ static void stroke_receive(private_stroke_t *this)
        /* disable cancellation by default */
        pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, NULL);
        
-       while (1)
+       while (TRUE)
        {
                /* wait for connections, but allow thread to terminate */
                pthread_setcancelstate(PTHREAD_CANCEL_ENABLE, &oldstate);
@@ -952,77 +1034,13 @@ static void stroke_receive(private_stroke_t *this)
                        continue;
                }
                
-               /* peek the length */
-               bytes_read = recv(strokefd, &msg_length, sizeof(msg_length), MSG_PEEK);
-               if (bytes_read != sizeof(msg_length))
-               {
-                       DBG1(DBG_CFG, "reading length of stroke message failed");
-                       close(strokefd);
-                       continue;
-               }
-               
-               /* read message */
-               msg = malloc(msg_length);
-               bytes_read = recv(strokefd, msg, msg_length, 0);
-               if (bytes_read != msg_length)
-               {
-                       DBG1(DBG_CFG, "reading stroke message failed: %m");
-                       close(strokefd);
-                       continue;
-               }
-               
-               this->out = fdopen(dup(strokefd), "w");
-               if (this->out == NULL)
-               {
-                       DBG1(DBG_CFG, "opening stroke output channel failed: %m");
-                       close(strokefd);
-                       free(msg);
-                       continue;
-               }
-               
-               DBG3(DBG_CFG, "stroke message %b", (void*)msg, msg_length);
-               
-               switch (msg->type)
-               {
-                       case STR_INITIATE:
-                               stroke_initiate(this, msg);
-                               break;
-                       case STR_ROUTE:
-                               stroke_route(this, msg, TRUE);
-                               break;
-                       case STR_UNROUTE:
-                               stroke_route(this, msg, FALSE);
-                               break;
-                       case STR_TERMINATE:
-                               stroke_terminate(this, msg);
-                               break;
-                       case STR_STATUS:
-                               stroke_status(this, msg);
-                               break;
-                       case STR_STATUS_ALL:
-                               stroke_statusall(this, msg);
-                               break;
-                       case STR_ADD_CONN:
-                               stroke_add_conn(this, msg);
-                               break;
-                       case STR_DEL_CONN:
-                               stroke_del_conn(this, msg);
-                               break;
-                       case STR_LOGLEVEL:
-                               stroke_loglevel(this, msg);
-                               break;
-                       case STR_LIST:
-                               stroke_list(this, msg);
-                               break;
-                       case STR_REREAD:
-                               stroke_reread(this, msg);
-                               break;
-                       default:
-                               DBG1(DBG_CFG, "received unknown stroke");
+               /* handle request asynchronously */
+               if (pthread_create(&thread, NULL, (void*(*)(void*))stroke_process, (void*)&strokefd) != 0)
+               {               
+                       DBG1(DBG_CFG, "failed to spawn stroke thread: %m");
                }
-               fclose(this->out);
-               close(strokefd);
-               free(msg);
+               /* detach so the thread terminates cleanly */
+               pthread_detach(thread);
        }
 }