Support signing of RADIUS response messages
[strongswan.git] / src / libcharon / plugins / eap_radius / radius_message.c
index 11a1d8d..9d7bf3e 100644 (file)
@@ -74,7 +74,14 @@ ENUM_BEGIN(radius_message_code_names, RMC_ACCESS_REQUEST, RMC_ACCOUNTING_RESPONS
        "Accounting-Response");
 ENUM_NEXT(radius_message_code_names, RMC_ACCESS_CHALLENGE, RMC_ACCESS_CHALLENGE, RMC_ACCOUNTING_RESPONSE,
        "Access-Challenge");
-ENUM_END(radius_message_code_names, RMC_ACCESS_CHALLENGE);
+ENUM_NEXT(radius_message_code_names, RMC_DISCONNECT_REQUEST, RMC_COA_NAK, RMC_ACCESS_CHALLENGE,
+       "Disconnect-Request",
+       "Disconnect-ACK",
+       "Disconnect-NAK",
+       "CoA-Request",
+       "CoA-ACK",
+       "CoA-NAK");
+ENUM_END(radius_message_code_names, RMC_COA_NAK);
 
 ENUM(radius_attribute_type_names, RAT_USER_NAME, RAT_MIP6_HOME_LINK_PREFIX,
        "User-Name",
@@ -215,13 +222,8 @@ typedef struct {
        int left;
 } attribute_enumerator_t;
 
-
-/**
- * Implementation of attribute_enumerator_t.enumerate
- */
-static bool attribute_enumerate(attribute_enumerator_t *this,
-                                                               int *type, chunk_t *data)
-
+METHOD(enumerator_t, attribute_enumerate, bool,
+       attribute_enumerator_t *this, int *type, chunk_t *data)
 {
        if (this->left == 0)
        {
@@ -241,10 +243,8 @@ static bool attribute_enumerate(attribute_enumerator_t *this,
        return TRUE;
 }
 
-/**
- * Implementation of radius_message_t.create_enumerator
- */
-static enumerator_t* create_enumerator(private_radius_message_t *this)
+METHOD(radius_message_t, create_enumerator, enumerator_t*,
+       private_radius_message_t *this)
 {
        attribute_enumerator_t *e;
 
@@ -252,20 +252,19 @@ static enumerator_t* create_enumerator(private_radius_message_t *this)
        {
                return enumerator_create_empty();
        }
-
-       e = malloc_thing(attribute_enumerator_t);
-       e->public.enumerate = (void*)attribute_enumerate;
-       e->public.destroy = (void*)free;
-       e->next = (rattr_t*)this->msg->attributes;
-       e->left = ntohs(this->msg->length) - sizeof(rmsg_t);
+       INIT(e,
+               .public = {
+                       .enumerate = (void*)_attribute_enumerate,
+                       .destroy = (void*)free,
+               },
+               .next = (rattr_t*)this->msg->attributes,
+               .left = ntohs(this->msg->length) - sizeof(rmsg_t),
+       );
        return &e->public;
 }
 
-/**
- * Implementation of radius_message_t.add
- */
-static void add(private_radius_message_t *this, radius_attribute_type_t type,
-                               chunk_t data)
+METHOD(radius_message_t, add, void,
+       private_radius_message_t *this, radius_attribute_type_t type, chunk_t data)
 {
        rattr_t *attribute;
 
@@ -279,29 +278,45 @@ static void add(private_radius_message_t *this, radius_attribute_type_t type,
        this->msg->length = htons(ntohs(this->msg->length) + attribute->length);
 }
 
-/**
- * Implementation of radius_message_t.sign
- */
-static void sign(private_radius_message_t *this, rng_t *rng, signer_t *signer)
+METHOD(radius_message_t, sign, void,
+       private_radius_message_t *this, u_int8_t *req_auth, chunk_t secret,
+       hasher_t *hasher, signer_t *signer, rng_t *rng)
 {
-       char buf[HASH_SIZE_MD5];
+       if (rng == NULL)
+       {
+               chunk_t msg;
+
+               if (req_auth)
+               {
+                       memcpy(this->msg->authenticator, req_auth, HASH_SIZE_MD5);
+               }
+               else
+               {
+                       memset(this->msg->authenticator, 0, sizeof(this->msg->authenticator));
+               }
+               msg = chunk_create((u_char*)this->msg, ntohs(this->msg->length));
+               hasher->get_hash(hasher, msg, NULL);
+               hasher->get_hash(hasher, secret, this->msg->authenticator);
+       }
+       else
+       {
+               char buf[HASH_SIZE_MD5];
 
-       /* build Request-Authenticator */
-       rng->get_bytes(rng, HASH_SIZE_MD5, this->msg->authenticator);
+               /* build Request-Authenticator */
+               rng->get_bytes(rng, HASH_SIZE_MD5, this->msg->authenticator);
 
-       /* build Message-Authenticator attribute, using 16 null bytes */
-       memset(buf, 0, sizeof(buf));
-       add(this, RAT_MESSAGE_AUTHENTICATOR, chunk_create(buf, sizeof(buf)));
-       signer->get_signature(signer,
+               /* build Message-Authenticator attribute, using 16 null bytes */
+               memset(buf, 0, sizeof(buf));
+               add(this, RAT_MESSAGE_AUTHENTICATOR, chunk_create(buf, sizeof(buf)));
+               signer->get_signature(signer,
                                chunk_create((u_char*)this->msg, ntohs(this->msg->length)),
                                ((u_char*)this->msg) + ntohs(this->msg->length) - HASH_SIZE_MD5);
+       }
 }
 
-/**
- * Implementation of radius_message_t.verify
- */
-static bool verify(private_radius_message_t *this, u_int8_t *req_auth,
-                                  chunk_t secret, hasher_t *hasher, signer_t *signer)
+METHOD(radius_message_t, verify, bool,
+       private_radius_message_t *this, u_int8_t *req_auth, chunk_t secret,
+       hasher_t *hasher, signer_t *signer)
 {
        char buf[HASH_SIZE_MD5], res_auth[HASH_SIZE_MD5];
        enumerator_t *enumerator;
@@ -311,7 +326,14 @@ static bool verify(private_radius_message_t *this, u_int8_t *req_auth,
 
        /* replace Response by Request Authenticator for verification */
        memcpy(res_auth, this->msg->authenticator, HASH_SIZE_MD5);
-       memcpy(this->msg->authenticator, req_auth, HASH_SIZE_MD5);
+       if (req_auth)
+       {
+               memcpy(this->msg->authenticator, req_auth, HASH_SIZE_MD5);
+       }
+       else
+       {
+               memset(this->msg->authenticator, 0, HASH_SIZE_MD5);
+       }
        msg = chunk_create((u_char*)this->msg, ntohs(this->msg->length));
 
        /* verify Response-Authenticator */
@@ -369,51 +391,39 @@ static bool verify(private_radius_message_t *this, u_int8_t *req_auth,
        return TRUE;
 }
 
-/**
- * Implementation of radius_message_t.get_code
- */
-static radius_message_code_t get_code(private_radius_message_t *this)
+METHOD(radius_message_t, get_code, radius_message_code_t,
+       private_radius_message_t *this)
 {
        return this->msg->code;
 }
 
-/**
- * Implementation of radius_message_t.get_identifier
- */
-static u_int8_t get_identifier(private_radius_message_t *this)
+METHOD(radius_message_t, get_identifier, u_int8_t,
+       private_radius_message_t *this)
 {
        return this->msg->identifier;
 }
 
-/**
- * Implementation of radius_message_t.set_identifier
- */
-static void set_identifier(private_radius_message_t *this, u_int8_t identifier)
+METHOD(radius_message_t, set_identifier, void,
+       private_radius_message_t *this, u_int8_t identifier)
 {
        this->msg->identifier = identifier;
 }
 
-/**
- * Implementation of radius_message_t.get_authenticator
- */
-static u_int8_t* get_authenticator(private_radius_message_t *this)
+METHOD(radius_message_t, get_authenticator, u_int8_t*,
+       private_radius_message_t *this)
 {
        return this->msg->authenticator;
 }
 
 
-/**
- * Implementation of radius_message_t.get_encoding
- */
-static chunk_t get_encoding(private_radius_message_t *this)
+METHOD(radius_message_t, get_encoding, chunk_t,
+       private_radius_message_t *this)
 {
        return chunk_create((u_char*)this->msg, ntohs(this->msg->length));
 }
 
-/**
- * Implementation of radius_message_t.destroy.
- */
-static void destroy(private_radius_message_t *this)
+METHOD(radius_message_t, destroy, void,
+       private_radius_message_t *this)
 {
        free(this->msg);
        free(this);
@@ -422,20 +432,24 @@ static void destroy(private_radius_message_t *this)
 /**
  * Generic constructor
  */
-static private_radius_message_t *radius_message_create()
+static private_radius_message_t *radius_message_create_empty()
 {
-       private_radius_message_t *this = malloc_thing(private_radius_message_t);
-
-       this->public.create_enumerator = (enumerator_t*(*)(radius_message_t*))create_enumerator;
-       this->public.add = (void(*)(radius_message_t*, radius_attribute_type_t,chunk_t))add;
-       this->public.get_code = (radius_message_code_t(*)(radius_message_t*))get_code;
-       this->public.get_identifier = (u_int8_t(*)(radius_message_t*))get_identifier;
-       this->public.set_identifier = (void(*)(radius_message_t*, u_int8_t identifier))set_identifier;
-       this->public.get_authenticator = (u_int8_t*(*)(radius_message_t*))get_authenticator;
-       this->public.get_encoding = (chunk_t(*)(radius_message_t*))get_encoding;
-       this->public.sign = (void(*)(radius_message_t*, rng_t *rng, signer_t *signer))sign;
-       this->public.verify = (bool(*)(radius_message_t*, u_int8_t *req_auth, chunk_t secret, hasher_t *hasher, signer_t *signer))verify;
-       this->public.destroy = (void(*)(radius_message_t*))destroy;
+       private_radius_message_t *this;
+
+       INIT(this,
+               .public = {
+                       .create_enumerator = _create_enumerator,
+                       .add = _add,
+                       .get_code = _get_code,
+                       .get_identifier = _get_identifier,
+                       .set_identifier = _set_identifier,
+                       .get_authenticator = _get_authenticator,
+                       .get_encoding = _get_encoding,
+                       .sign = _sign,
+                       .verify = _verify,
+                       .destroy = _destroy,
+               },
+       );
 
        return this;
 }
@@ -443,14 +457,15 @@ static private_radius_message_t *radius_message_create()
 /**
  * See header
  */
-radius_message_t *radius_message_create_request()
+radius_message_t *radius_message_create(radius_message_code_t code)
 {
-       private_radius_message_t *this = radius_message_create();
+       private_radius_message_t *this = radius_message_create_empty();
 
-       this->msg = malloc_thing(rmsg_t);
-       this->msg->code = RMC_ACCESS_REQUEST;
-       this->msg->identifier = 0;
-       this->msg->length = htons(sizeof(rmsg_t));
+       INIT(this->msg,
+               .code = code,
+               .identifier = 0,
+               .length = htons(sizeof(rmsg_t)),
+       );
 
        return &this->public;
 }
@@ -458,9 +473,9 @@ radius_message_t *radius_message_create_request()
 /**
  * See header
  */
-radius_message_t *radius_message_parse_response(chunk_t data)
+radius_message_t *radius_message_parse(chunk_t data)
 {
-       private_radius_message_t *this = radius_message_create();
+       private_radius_message_t *this = radius_message_create_empty();
 
        this->msg = malloc(data.len);
        memcpy(this->msg, data.ptr, data.len);