aesni: Use 4-way parallel AES-NI instructions for CBC decryption
author    Martin Willi <martin@revosec.ch>
Thu, 26 Mar 2015 07:34:00 +0000 (08:34 +0100)
committer Martin Willi <martin@revosec.ch>
Wed, 15 Apr 2015 09:35:27 +0000 (11:35 +0200)
CBC decryption can be parallelized, and we do so by queueing instructions
to the processor pipeline. While we have enough registers for 128-bit
decryption, the register count is insufficient to hold all variables with
larger key sizes. Nonetheless, 4-way parallelism is faster, by roughly 8%.

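For orientation, a minimal C sketch of the loop structure introduced below — not the plugin code itself. block_decrypt(), PARALLELISM and BLOCK_SIZE are hypothetical stand-ins for the per-block AES-NI round sequence and the constants in aesni_cbc.c. It illustrates why CBC decryption parallelizes at all: each plaintext block only needs the raw ciphertext of its predecessor, which is already available, so several decryptions can be issued back-to-back before the XOR step.

    #include <string.h>

    #define PARALLELISM 4   /* matches CBC_DECRYPT_PARALLELISM below */
    #define BLOCK_SIZE 16   /* AES block size in bytes */

    /* hypothetical stand-in for the AES-NI round sequence on one block */
    static void block_decrypt(const unsigned char *in, unsigned char *out);

    static void cbc_decrypt_sketch(unsigned char *in, unsigned char *out,
                                   unsigned int blocks, unsigned char *iv)
    {
        unsigned char fb[BLOCK_SIZE];
        unsigned char c[PARALLELISM][BLOCK_SIZE], t[PARALLELISM][BLOCK_SIZE];
        unsigned int i, j, b, pblocks = blocks - (blocks % PARALLELISM);

        memcpy(fb, iv, BLOCK_SIZE);

        for (i = 0; i < pblocks; i += PARALLELISM)
        {
            /* copy the ciphertext first so in-place decryption (in == out) stays correct */
            for (j = 0; j < PARALLELISM; j++)
            {
                memcpy(c[j], in + (i + j) * BLOCK_SIZE, BLOCK_SIZE);
            }
            /* four independent decryptions; with AES-NI these interleave in the pipeline */
            for (j = 0; j < PARALLELISM; j++)
            {
                block_decrypt(c[j], t[j]);
            }
            /* XOR each result with the previous ciphertext block (or the IV) and store */
            for (j = 0; j < PARALLELISM; j++)
            {
                const unsigned char *prev = (j == 0) ? fb : c[j - 1];
                for (b = 0; b < BLOCK_SIZE; b++)
                {
                    out[(i + j) * BLOCK_SIZE + b] = t[j][b] ^ prev[b];
                }
            }
            memcpy(fb, c[PARALLELISM - 1], BLOCK_SIZE);
        }

        /* remaining blocks that do not fill a group of four, handled one at a time */
        for (i = pblocks; i < blocks; i++)
        {
            memcpy(c[0], in + i * BLOCK_SIZE, BLOCK_SIZE);
            block_decrypt(c[0], t[0]);
            for (b = 0; b < BLOCK_SIZE; b++)
            {
                out[i * BLOCK_SIZE + b] = t[0][b] ^ fb[b];
            }
            memcpy(fb, c[0], BLOCK_SIZE);
        }
    }

The diff below applies exactly this split — a 4-way main loop followed by a scalar tail loop — to the 128-, 192- and 256-bit variants.
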
src/libstrongswan/plugins/aesni/aesni_cbc.c

index 6fba6d1..cf18faf 100644
 #include "aesni_cbc.h"
 #include "aesni_key.h"
 
+/**
+ * Pipeline parallelism we use for CBC decryption
+ */
+#define CBC_DECRYPT_PARALLELISM 4
+
 typedef struct private_aesni_cbc_t private_aesni_cbc_t;
 
 /**
@@ -113,8 +118,10 @@ static void decrypt_cbc128(aesni_key_t *key, u_int blocks, u_char *in,
                                                   u_char *iv, u_char *out)
 {
        __m128i k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10;
-       __m128i t, fb, last, *bi, *bo;
-       int i;
+       __m128i last, *bi, *bo;
+       __m128i t1, t2, t3, t4;
+       __m128i f1, f2, f3, f4;
+       u_int i, pblocks;
 
        k0 = key->schedule[0];
        k1 = key->schedule[1];
@@ -130,27 +137,98 @@ static void decrypt_cbc128(aesni_key_t *key, u_int blocks, u_char *in,
 
        bi = (__m128i*)in;
        bo = (__m128i*)out;
+       pblocks = blocks - (blocks % CBC_DECRYPT_PARALLELISM);
 
-       fb = _mm_loadu_si128((__m128i*)iv);
-       for (i = 0; i < blocks; i++)
+       f1 = _mm_loadu_si128((__m128i*)iv);
+
+       for (i = 0; i < pblocks; i += CBC_DECRYPT_PARALLELISM)
+       {
+               t1 = _mm_loadu_si128(bi + i + 0);
+               t2 = _mm_loadu_si128(bi + i + 1);
+               t3 = _mm_loadu_si128(bi + i + 2);
+               t4 = _mm_loadu_si128(bi + i + 3);
+
+               f2 = t1;
+               f3 = t2;
+               f4 = t3;
+               last = t4;
+
+               t1 = _mm_xor_si128(t1, k0);
+               t2 = _mm_xor_si128(t2, k0);
+               t3 = _mm_xor_si128(t3, k0);
+               t4 = _mm_xor_si128(t4, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t2 = _mm_aesdec_si128(t2, k1);
+               t3 = _mm_aesdec_si128(t3, k1);
+               t4 = _mm_aesdec_si128(t4, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t2 = _mm_aesdec_si128(t2, k2);
+               t3 = _mm_aesdec_si128(t3, k2);
+               t4 = _mm_aesdec_si128(t4, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t2 = _mm_aesdec_si128(t2, k3);
+               t3 = _mm_aesdec_si128(t3, k3);
+               t4 = _mm_aesdec_si128(t4, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t2 = _mm_aesdec_si128(t2, k4);
+               t3 = _mm_aesdec_si128(t3, k4);
+               t4 = _mm_aesdec_si128(t4, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t2 = _mm_aesdec_si128(t2, k5);
+               t3 = _mm_aesdec_si128(t3, k5);
+               t4 = _mm_aesdec_si128(t4, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t2 = _mm_aesdec_si128(t2, k6);
+               t3 = _mm_aesdec_si128(t3, k6);
+               t4 = _mm_aesdec_si128(t4, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t2 = _mm_aesdec_si128(t2, k7);
+               t3 = _mm_aesdec_si128(t3, k7);
+               t4 = _mm_aesdec_si128(t4, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t2 = _mm_aesdec_si128(t2, k8);
+               t3 = _mm_aesdec_si128(t3, k8);
+               t4 = _mm_aesdec_si128(t4, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+               t2 = _mm_aesdec_si128(t2, k9);
+               t3 = _mm_aesdec_si128(t3, k9);
+               t4 = _mm_aesdec_si128(t4, k9);
+
+               t1 = _mm_aesdeclast_si128(t1, k10);
+               t2 = _mm_aesdeclast_si128(t2, k10);
+               t3 = _mm_aesdeclast_si128(t3, k10);
+               t4 = _mm_aesdeclast_si128(t4, k10);
+               t1 = _mm_xor_si128(t1, f1);
+               t2 = _mm_xor_si128(t2, f2);
+               t3 = _mm_xor_si128(t3, f3);
+               t4 = _mm_xor_si128(t4, f4);
+               _mm_storeu_si128(bo + i + 0, t1);
+               _mm_storeu_si128(bo + i + 1, t2);
+               _mm_storeu_si128(bo + i + 2, t3);
+               _mm_storeu_si128(bo + i + 3, t4);
+               f1 = last;
+       }
+
+       for (i = pblocks; i < blocks; i++)
        {
                last = _mm_loadu_si128(bi + i);
-               t = _mm_xor_si128(last, k0);
-
-               t = _mm_aesdec_si128(t, k1);
-               t = _mm_aesdec_si128(t, k2);
-               t = _mm_aesdec_si128(t, k3);
-               t = _mm_aesdec_si128(t, k4);
-               t = _mm_aesdec_si128(t, k5);
-               t = _mm_aesdec_si128(t, k6);
-               t = _mm_aesdec_si128(t, k7);
-               t = _mm_aesdec_si128(t, k8);
-               t = _mm_aesdec_si128(t, k9);
-
-               t = _mm_aesdeclast_si128(t, k10);
-               t = _mm_xor_si128(t, fb);
-               _mm_storeu_si128(bo + i, t);
-               fb = last;
+               t1 = _mm_xor_si128(last, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+
+               t1 = _mm_aesdeclast_si128(t1, k10);
+               t1 = _mm_xor_si128(t1, f1);
+               _mm_storeu_si128(bo + i, t1);
+               f1 = last;
        }
 }
 
@@ -212,8 +290,10 @@ static void decrypt_cbc192(aesni_key_t *key, u_int blocks, u_char *in,
                                                   u_char *iv, u_char *out)
 {
        __m128i k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12;
-       __m128i t, fb, last, *bi, *bo;
-       int i;
+       __m128i last, *bi, *bo;
+       __m128i t1, t2, t3, t4;
+       __m128i f1, f2, f3, f4;
+       u_int i, pblocks;
 
        k0 = key->schedule[0];
        k1 = key->schedule[1];
@@ -231,29 +311,108 @@ static void decrypt_cbc192(aesni_key_t *key, u_int blocks, u_char *in,
 
        bi = (__m128i*)in;
        bo = (__m128i*)out;
+       pblocks = blocks - (blocks % CBC_DECRYPT_PARALLELISM);
 
-       fb = _mm_loadu_si128((__m128i*)iv);
-       for (i = 0; i < blocks; i++)
+       f1 = _mm_loadu_si128((__m128i*)iv);
+
+       for (i = 0; i < pblocks; i += CBC_DECRYPT_PARALLELISM)
+       {
+               t1 = _mm_loadu_si128(bi + i + 0);
+               t2 = _mm_loadu_si128(bi + i + 1);
+               t3 = _mm_loadu_si128(bi + i + 2);
+               t4 = _mm_loadu_si128(bi + i + 3);
+
+               f2 = t1;
+               f3 = t2;
+               f4 = t3;
+               last = t4;
+
+               t1 = _mm_xor_si128(t1, k0);
+               t2 = _mm_xor_si128(t2, k0);
+               t3 = _mm_xor_si128(t3, k0);
+               t4 = _mm_xor_si128(t4, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t2 = _mm_aesdec_si128(t2, k1);
+               t3 = _mm_aesdec_si128(t3, k1);
+               t4 = _mm_aesdec_si128(t4, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t2 = _mm_aesdec_si128(t2, k2);
+               t3 = _mm_aesdec_si128(t3, k2);
+               t4 = _mm_aesdec_si128(t4, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t2 = _mm_aesdec_si128(t2, k3);
+               t3 = _mm_aesdec_si128(t3, k3);
+               t4 = _mm_aesdec_si128(t4, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t2 = _mm_aesdec_si128(t2, k4);
+               t3 = _mm_aesdec_si128(t3, k4);
+               t4 = _mm_aesdec_si128(t4, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t2 = _mm_aesdec_si128(t2, k5);
+               t3 = _mm_aesdec_si128(t3, k5);
+               t4 = _mm_aesdec_si128(t4, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t2 = _mm_aesdec_si128(t2, k6);
+               t3 = _mm_aesdec_si128(t3, k6);
+               t4 = _mm_aesdec_si128(t4, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t2 = _mm_aesdec_si128(t2, k7);
+               t3 = _mm_aesdec_si128(t3, k7);
+               t4 = _mm_aesdec_si128(t4, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t2 = _mm_aesdec_si128(t2, k8);
+               t3 = _mm_aesdec_si128(t3, k8);
+               t4 = _mm_aesdec_si128(t4, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+               t2 = _mm_aesdec_si128(t2, k9);
+               t3 = _mm_aesdec_si128(t3, k9);
+               t4 = _mm_aesdec_si128(t4, k9);
+               t1 = _mm_aesdec_si128(t1, k10);
+               t2 = _mm_aesdec_si128(t2, k10);
+               t3 = _mm_aesdec_si128(t3, k10);
+               t4 = _mm_aesdec_si128(t4, k10);
+               t1 = _mm_aesdec_si128(t1, k11);
+               t2 = _mm_aesdec_si128(t2, k11);
+               t3 = _mm_aesdec_si128(t3, k11);
+               t4 = _mm_aesdec_si128(t4, k11);
+
+               t1 = _mm_aesdeclast_si128(t1, k12);
+               t2 = _mm_aesdeclast_si128(t2, k12);
+               t3 = _mm_aesdeclast_si128(t3, k12);
+               t4 = _mm_aesdeclast_si128(t4, k12);
+               t1 = _mm_xor_si128(t1, f1);
+               t2 = _mm_xor_si128(t2, f2);
+               t3 = _mm_xor_si128(t3, f3);
+               t4 = _mm_xor_si128(t4, f4);
+               _mm_storeu_si128(bo + i + 0, t1);
+               _mm_storeu_si128(bo + i + 1, t2);
+               _mm_storeu_si128(bo + i + 2, t3);
+               _mm_storeu_si128(bo + i + 3, t4);
+               f1 = last;
+       }
+
+       for (i = pblocks; i < blocks; i++)
        {
                last = _mm_loadu_si128(bi + i);
-               t = _mm_xor_si128(last, k0);
-
-               t = _mm_aesdec_si128(t, k1);
-               t = _mm_aesdec_si128(t, k2);
-               t = _mm_aesdec_si128(t, k3);
-               t = _mm_aesdec_si128(t, k4);
-               t = _mm_aesdec_si128(t, k5);
-               t = _mm_aesdec_si128(t, k6);
-               t = _mm_aesdec_si128(t, k7);
-               t = _mm_aesdec_si128(t, k8);
-               t = _mm_aesdec_si128(t, k9);
-               t = _mm_aesdec_si128(t, k10);
-               t = _mm_aesdec_si128(t, k11);
-
-               t = _mm_aesdeclast_si128(t, k12);
-               t = _mm_xor_si128(t, fb);
-               _mm_storeu_si128(bo + i, t);
-               fb = last;
+               t1 = _mm_xor_si128(last, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+               t1 = _mm_aesdec_si128(t1, k10);
+               t1 = _mm_aesdec_si128(t1, k11);
+
+               t1 = _mm_aesdeclast_si128(t1, k12);
+               t1 = _mm_xor_si128(t1, f1);
+               _mm_storeu_si128(bo + i, t1);
+               f1 = last;
        }
 }
 
@@ -319,8 +478,10 @@ static void decrypt_cbc256(aesni_key_t *key, u_int blocks, u_char *in,
                                                   u_char *iv, u_char *out)
 {
        __m128i k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14;
-       __m128i t, fb, last, *bi, *bo;
-       int i;
+       __m128i last, *bi, *bo;
+       __m128i t1, t2, t3, t4;
+       __m128i f1, f2, f3, f4;
+       u_int i, pblocks;
 
        k0 = key->schedule[0];
        k1 = key->schedule[1];
@@ -340,31 +501,118 @@ static void decrypt_cbc256(aesni_key_t *key, u_int blocks, u_char *in,
 
        bi = (__m128i*)in;
        bo = (__m128i*)out;
+       pblocks = blocks - (blocks % CBC_DECRYPT_PARALLELISM);
 
-       fb = _mm_loadu_si128((__m128i*)iv);
-       for (i = 0; i < blocks; i++)
+       f1 = _mm_loadu_si128((__m128i*)iv);
+
+       for (i = 0; i < pblocks; i += CBC_DECRYPT_PARALLELISM)
+       {
+               t1 = _mm_loadu_si128(bi + i + 0);
+               t2 = _mm_loadu_si128(bi + i + 1);
+               t3 = _mm_loadu_si128(bi + i + 2);
+               t4 = _mm_loadu_si128(bi + i + 3);
+
+               f2 = t1;
+               f3 = t2;
+               f4 = t3;
+               last = t4;
+
+               t1 = _mm_xor_si128(t1, k0);
+               t2 = _mm_xor_si128(t2, k0);
+               t3 = _mm_xor_si128(t3, k0);
+               t4 = _mm_xor_si128(t4, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t2 = _mm_aesdec_si128(t2, k1);
+               t3 = _mm_aesdec_si128(t3, k1);
+               t4 = _mm_aesdec_si128(t4, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t2 = _mm_aesdec_si128(t2, k2);
+               t3 = _mm_aesdec_si128(t3, k2);
+               t4 = _mm_aesdec_si128(t4, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t2 = _mm_aesdec_si128(t2, k3);
+               t3 = _mm_aesdec_si128(t3, k3);
+               t4 = _mm_aesdec_si128(t4, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t2 = _mm_aesdec_si128(t2, k4);
+               t3 = _mm_aesdec_si128(t3, k4);
+               t4 = _mm_aesdec_si128(t4, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t2 = _mm_aesdec_si128(t2, k5);
+               t3 = _mm_aesdec_si128(t3, k5);
+               t4 = _mm_aesdec_si128(t4, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t2 = _mm_aesdec_si128(t2, k6);
+               t3 = _mm_aesdec_si128(t3, k6);
+               t4 = _mm_aesdec_si128(t4, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t2 = _mm_aesdec_si128(t2, k7);
+               t3 = _mm_aesdec_si128(t3, k7);
+               t4 = _mm_aesdec_si128(t4, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t2 = _mm_aesdec_si128(t2, k8);
+               t3 = _mm_aesdec_si128(t3, k8);
+               t4 = _mm_aesdec_si128(t4, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+               t2 = _mm_aesdec_si128(t2, k9);
+               t3 = _mm_aesdec_si128(t3, k9);
+               t4 = _mm_aesdec_si128(t4, k9);
+               t1 = _mm_aesdec_si128(t1, k10);
+               t2 = _mm_aesdec_si128(t2, k10);
+               t3 = _mm_aesdec_si128(t3, k10);
+               t4 = _mm_aesdec_si128(t4, k10);
+               t1 = _mm_aesdec_si128(t1, k11);
+               t2 = _mm_aesdec_si128(t2, k11);
+               t3 = _mm_aesdec_si128(t3, k11);
+               t4 = _mm_aesdec_si128(t4, k11);
+               t1 = _mm_aesdec_si128(t1, k12);
+               t2 = _mm_aesdec_si128(t2, k12);
+               t3 = _mm_aesdec_si128(t3, k12);
+               t4 = _mm_aesdec_si128(t4, k12);
+               t1 = _mm_aesdec_si128(t1, k13);
+               t2 = _mm_aesdec_si128(t2, k13);
+               t3 = _mm_aesdec_si128(t3, k13);
+               t4 = _mm_aesdec_si128(t4, k13);
+
+               t1 = _mm_aesdeclast_si128(t1, k14);
+               t2 = _mm_aesdeclast_si128(t2, k14);
+               t3 = _mm_aesdeclast_si128(t3, k14);
+               t4 = _mm_aesdeclast_si128(t4, k14);
+               t1 = _mm_xor_si128(t1, f1);
+               t2 = _mm_xor_si128(t2, f2);
+               t3 = _mm_xor_si128(t3, f3);
+               t4 = _mm_xor_si128(t4, f4);
+               _mm_storeu_si128(bo + i + 0, t1);
+               _mm_storeu_si128(bo + i + 1, t2);
+               _mm_storeu_si128(bo + i + 2, t3);
+               _mm_storeu_si128(bo + i + 3, t4);
+               f1 = last;
+       }
+
+       for (i = pblocks; i < blocks; i++)
        {
                last = _mm_loadu_si128(bi + i);
-               t = _mm_xor_si128(last, k0);
-
-               t = _mm_aesdec_si128(t, k1);
-               t = _mm_aesdec_si128(t, k2);
-               t = _mm_aesdec_si128(t, k3);
-               t = _mm_aesdec_si128(t, k4);
-               t = _mm_aesdec_si128(t, k5);
-               t = _mm_aesdec_si128(t, k6);
-               t = _mm_aesdec_si128(t, k7);
-               t = _mm_aesdec_si128(t, k8);
-               t = _mm_aesdec_si128(t, k9);
-               t = _mm_aesdec_si128(t, k10);
-               t = _mm_aesdec_si128(t, k11);
-               t = _mm_aesdec_si128(t, k12);
-               t = _mm_aesdec_si128(t, k13);
-
-               t = _mm_aesdeclast_si128(t, k14);
-               t = _mm_xor_si128(t, fb);
-               _mm_storeu_si128(bo + i, t);
-               fb = last;
+               t1 = _mm_xor_si128(last, k0);
+
+               t1 = _mm_aesdec_si128(t1, k1);
+               t1 = _mm_aesdec_si128(t1, k2);
+               t1 = _mm_aesdec_si128(t1, k3);
+               t1 = _mm_aesdec_si128(t1, k4);
+               t1 = _mm_aesdec_si128(t1, k5);
+               t1 = _mm_aesdec_si128(t1, k6);
+               t1 = _mm_aesdec_si128(t1, k7);
+               t1 = _mm_aesdec_si128(t1, k8);
+               t1 = _mm_aesdec_si128(t1, k9);
+               t1 = _mm_aesdec_si128(t1, k10);
+               t1 = _mm_aesdec_si128(t1, k11);
+               t1 = _mm_aesdec_si128(t1, k12);
+               t1 = _mm_aesdec_si128(t1, k13);
+
+               t1 = _mm_aesdeclast_si128(t1, k14);
+               t1 = _mm_xor_si128(t1, f1);
+               _mm_storeu_si128(bo + i, t1);
+               f1 = last;
        }
 }