tls-server: Refactor writing of key share extensions
authorPascal Knecht <pascal.knecht@hsr.ch>
Mon, 21 Sep 2020 20:19:34 +0000 (22:19 +0200)
committerTobias Brunner <tobias@strongswan.org>
Fri, 12 Feb 2021 13:35:23 +0000 (14:35 +0100)
Client and server now share the same code to write this extension.

src/libtls/tls_peer.c
src/libtls/tls_server.c

index 4e5c2cc..7d6c1ff 100644 (file)
@@ -157,6 +157,10 @@ struct private_tls_peer_t {
        chunk_t cert_types;
 };
 
+/* Implemented in tls_server.c */
+bool tls_write_key_share(bio_writer_t **key_share, tls_named_group_t group,
+                                                diffie_hellman_t *dh);
+
 /**
  * Verify the DH group/key type requested by the server is valid.
  */
@@ -1202,7 +1206,6 @@ static status_t send_client_hello(private_tls_peer_t *this,
        enumerator_t *enumerator;
        int count, i, v;
        rng_t *rng;
-       chunk_t pub;
 
        htoun32(&this->client_random, time(NULL));
        rng = lib->crypto->create_rng(lib->crypto, RNG_WEAK);
@@ -1352,34 +1355,21 @@ static status_t send_client_hello(private_tls_peer_t *this,
        extensions->write_data16(extensions, signatures->get_buf(signatures));
        signatures->destroy(signatures);
 
-       if (this->dh)
+       if (this->tls->get_version_max(this->tls) >= TLS_1_3 &&
+               this->dh)
        {
                DBG2(DBG_TLS, "sending extension: %N",
                         tls_extension_names, TLS_EXT_KEY_SHARE);
-               if (!this->dh->get_my_public_value(this->dh, &pub))
+               extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE);
+               if (!tls_write_key_share(&key_share, selected_curve, this->dh))
                {
                        this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR);
                        extensions->destroy(extensions);
                        return NEED_MORE;
                }
-               extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE);
-               key_share = bio_writer_create(pub.len + 6);
-               key_share->write_uint16(key_share, selected_curve);
-               if (selected_curve == TLS_CURVE25519 ||
-                       selected_curve == TLS_CURVE448)
-               {
-                       key_share->write_data16(key_share, pub);
-               }
-               else
-               {       /* classic format (see RFC 8446, section 4.2.8.2) */
-                       key_share->write_uint16(key_share, pub.len + 1);
-                       key_share->write_uint8(key_share, TLS_ANSI_UNCOMPRESSED);
-                       key_share->write_data(key_share, pub);
-               }
                key_share->wrap16(key_share);
                extensions->write_data16(extensions, key_share->get_buf(key_share));
                key_share->destroy(key_share);
-               free(pub.ptr);
        }
 
        writer->write_data16(writer, extensions->get_buf(extensions));
index c858252..fb897cf 100644 (file)
@@ -340,8 +340,7 @@ static status_t process_client_hello(private_tls_server_t *this,
                                {
                                        DBG1(DBG_TLS, "invalid %N extension",
                                                 tls_extension_names, extension_type);
-                                       this->alert->add(this->alert, TLS_FATAL,
-                                                                        TLS_DECODE_ERROR);
+                                       this->alert->add(this->alert, TLS_FATAL, TLS_DECODE_ERROR);
                                        extensions->destroy(extensions);
                                        extension->destroy(extension);
                                        return NEED_MORE;
@@ -916,14 +915,45 @@ METHOD(tls_handshake_t, process, status_t,
 }
 
 /**
+ * Write public key into key share extension
+ */
+bool tls_write_key_share(bio_writer_t **key_share, tls_named_group_t group,
+                                                diffie_hellman_t *dh)
+{
+       bio_writer_t *writer;
+       chunk_t pub;
+
+       if (!dh || !dh->get_my_public_value(dh, &pub))
+       {
+               return FALSE;
+       }
+       *key_share = writer = bio_writer_create(pub.len + 7);
+       writer->write_uint16(writer, group);
+       if (group == TLS_CURVE25519 ||
+               group == TLS_CURVE448)
+       {
+               writer->write_data16(writer, pub);
+       }
+       else
+       {       /* classic format (see RFC 8446, section 4.2.8.2) */
+               writer->write_uint16(writer, pub.len + 1);
+               writer->write_uint8(writer, TLS_ANSI_UNCOMPRESSED);
+               writer->write_data(writer, pub);
+       }
+       free(pub.ptr);
+       return TRUE;
+}
+
+/**
  * Send ServerHello message
  */
 static status_t send_server_hello(private_tls_server_t *this,
                                                        tls_handshake_type_t *type, bio_writer_t *writer)
 {
-       bio_writer_t *extensions, *key_share;
-       tls_version_t version = this->tls->get_version_max(this->tls);
-       chunk_t pub;
+       bio_writer_t *key_share, *extensions;
+       tls_version_t version;
+
+       version = this->tls->get_version_max(this->tls);
 
        /* cap legacy version at TLS 1.2 for middlebox compatibility */
        writer->write_uint16(writer, min(TLS_1_2, version));
@@ -948,36 +978,18 @@ static status_t send_server_hello(private_tls_server_t *this,
                extensions->write_uint16(extensions, 2);
                extensions->write_uint16(extensions, version);
 
-               if (this->dh)
-               {
-                       tls_named_group_t selected_curve = this->requested_curve;
+               DBG2(DBG_TLS, "sending extension: %N",
+                        tls_extension_names, TLS_EXT_KEY_SHARE);
+               extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE);
 
-                       DBG2(DBG_TLS, "sending extension: %N",
-                                tls_extension_names, TLS_EXT_KEY_SHARE);
-                       if (!this->dh->get_my_public_value(this->dh, &pub))
-                       {
-                               this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR);
-                               extensions->destroy(extensions);
-                               return NEED_MORE;
-                       }
-                       extensions->write_uint16(extensions, TLS_EXT_KEY_SHARE);
-                       key_share = bio_writer_create(pub.len + 6);
-                       key_share->write_uint16(key_share, selected_curve);
-                       if (selected_curve == TLS_CURVE25519 ||
-                               selected_curve == TLS_CURVE448)
-                       {
-                               key_share->write_data16(key_share, pub);
-                       }
-                       else
-                       {       /* classic format (see RFC 8446, section 4.2.8.2) */
-                               key_share->write_uint16(key_share, pub.len + 1);
-                               key_share->write_uint8(key_share, TLS_ANSI_UNCOMPRESSED);
-                               key_share->write_data(key_share, pub);
-                       }
-                       extensions->write_data16(extensions, key_share->get_buf(key_share));
-                       key_share->destroy(key_share);
-                       free(pub.ptr);
+               if (!tls_write_key_share(&key_share, this->requested_curve, this->dh))
+               {
+                       this->alert->add(this->alert, TLS_FATAL, TLS_INTERNAL_ERROR);
+                       extensions->destroy(extensions);
+                       return NEED_MORE;
                }
+               extensions->write_data16(extensions, key_share->get_buf(key_share));
+               key_share->destroy(key_share);
 
                writer->write_data16(writer, extensions->get_buf(extensions));
                extensions->destroy(extensions);