chapoly: Process four ChaCha20 blocks in parallel in SSSE3 driver
author Martin Willi <martin@revosec.ch>
Sun, 5 Apr 2015 19:50:03 +0000 (21:50 +0200)
committer Martin Willi <martin@revosec.ch>
Sun, 12 Jul 2015 11:25:36 +0000 (13:25 +0200)
Since the four-block path keeps each of the 16 state words in its own register, the state no
longer has to be shuffled between the quarter rounds of each ChaCha double round; overall
ChaCha20-Poly1305 performance increases by ~40%.
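
For reference, a minimal sketch (not part of the patch) of why the four-block layout avoids the
per-round shuffles: each __m128i is assumed to hold the same state word of all four blocks, so a
quarter round operates on whole registers and no _mm_shuffle_epi32 lane rotation is needed between
the column and diagonal rounds. The helper names rotl32x4() and quarter_round_x4() are illustrative
only; the patch itself rotates by 16 and 8 bits with a pshufb byte shuffle (r16/r8), whereas this
sketch uses shift/or rotates throughout.

    #include <emmintrin.h>

    /* rotate each 32-bit lane left by r bits */
    static inline __m128i rotl32x4(__m128i v, int r)
    {
            return _mm_or_si128(_mm_slli_epi32(v, r), _mm_srli_epi32(v, 32 - r));
    }

    /* one ChaCha quarter round; a, b, c, d each carry word i of blocks 0..3 */
    static inline void quarter_round_x4(__m128i *a, __m128i *b, __m128i *c, __m128i *d)
    {
            *a = _mm_add_epi32(*a, *b); *d = rotl32x4(_mm_xor_si128(*d, *a), 16);
            *c = _mm_add_epi32(*c, *d); *b = rotl32x4(_mm_xor_si128(*b, *c), 12);
            *a = _mm_add_epi32(*a, *b); *d = rotl32x4(_mm_xor_si128(*d, *a), 8);
            *c = _mm_add_epi32(*c, *d); *b = rotl32x4(_mm_xor_si128(*b, *c), 7);
    }

In the single-block function below, four state words share one register, so the lanes must be
rotated after each column/diagonal pass; in chacha_4block_xor those rotations disappear, at the
cost of transposing the result once before the final XOR.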

src/libstrongswan/plugins/chapoly/chapoly_drv_ssse3.c

index ab0c69e..a33bbf1 100644
@@ -94,6 +94,23 @@ static inline u_long and(u_long v, u_long mask)
 }
 
 /**
+ * r = shuffle(a ^ b, s)
+ */
+static inline __m128i sfflxor32(__m128i a, __m128i b, __m128i s)
+{
+       return _mm_shuffle_epi8(_mm_xor_si128(a, b), s);
+}
+
+/**
+ * r = rotl32(a ^ b, r)
+ */
+static inline __m128i rotlxor32(__m128i a, __m128i b, u_char r)
+{
+       a = _mm_xor_si128(a, b);
+       return _mm_or_si128(_mm_slli_epi32(a, r), _mm_srli_epi32(a, 32 - r));
+}
+
+/**
  * XOR a Chacha20 keystream block into data, increment counter
  */
 static void chacha_block_xor(private_chapoly_drv_ssse3_t *this, void *data)
@@ -112,40 +129,32 @@ static void chacha_block_xor(private_chapoly_drv_ssse3_t *this, void *data)
        for (i = 0 ; i < CHACHA_DOUBLEROUNDS; i++)
        {
                x0 = _mm_add_epi32(x0, x1);
-               x3 = _mm_xor_si128(x3, x0);
-               x3 = _mm_shuffle_epi8(x3, r16);
+               x3 = sfflxor32(x3, x0, r16);
 
                x2 = _mm_add_epi32(x2, x3);
-               x1 = _mm_xor_si128(x1, x2);
-               x1 = _mm_xor_si128(_mm_slli_epi32(x1, 12), _mm_srli_epi32(x1, 20));
+               x1 = rotlxor32(x1, x2, 12);
 
                x0 = _mm_add_epi32(x0, x1);
-               x3 = _mm_xor_si128(x3, x0);
-               x3 = _mm_shuffle_epi8(x3, r8);
+               x3 = sfflxor32(x3, x0, r8);
 
                x2 = _mm_add_epi32(x2, x3);
-               x1 = _mm_xor_si128(x1, x2);
-               x1 = _mm_xor_si128(_mm_slli_epi32(x1, 7), _mm_srli_epi32(x1, 25));
+               x1 = rotlxor32(x1, x2, 7);
 
                x1 = _mm_shuffle_epi32(x1, _MM_SHUFFLE(0, 3, 2, 1));
                x2 = _mm_shuffle_epi32(x2, _MM_SHUFFLE(1, 0, 3, 2));
                x3 = _mm_shuffle_epi32(x3, _MM_SHUFFLE(2, 1, 0, 3));
 
                x0 = _mm_add_epi32(x0, x1);
-               x3 = _mm_xor_si128(x3, x0);
-               x3 = _mm_shuffle_epi8(x3, r16);
+               x3 = sfflxor32(x3, x0, r16);
 
                x2 = _mm_add_epi32(x2, x3);
-               x1 = _mm_xor_si128(x1, x2);
-               x1 = _mm_xor_si128(_mm_slli_epi32(x1, 12), _mm_srli_epi32(x1, 20));
+               x1 = rotlxor32(x1, x2, 12);
 
                x0 = _mm_add_epi32(x0, x1);
-               x3 = _mm_xor_si128(x3, x0);
-               x3 = _mm_shuffle_epi8(x3, r8);
+               x3 = sfflxor32(x3, x0, r8);
 
                x2 = _mm_add_epi32(x2, x3);
-               x1 = _mm_xor_si128(x1, x2);
-               x1 = _mm_xor_si128(_mm_slli_epi32(x1, 7), _mm_srli_epi32(x1, 25));
+               x1 = rotlxor32(x1, x2, 7);
 
                x1 = _mm_shuffle_epi32(x1, _MM_SHUFFLE(2, 1, 0, 3));
                x2 = _mm_shuffle_epi32(x2, _MM_SHUFFLE(1, 0, 3, 2));
@@ -168,6 +177,174 @@ static void chacha_block_xor(private_chapoly_drv_ssse3_t *this, void *data)
        this->m[3] = _mm_add_epi32(this->m[3], _mm_set_epi32(0, 0, 0, 1));
 }
 
+/**
+ * XOR four Chacha20 keystream blocks into data, increment counter
+ */
+static void chacha_4block_xor(private_chapoly_drv_ssse3_t *this, void *data)
+{
+       __m128i x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xa, xb, xc, xd, xe, xf;
+       __m128i r8, r16, ctrinc, t, *out = data;
+       u_int32_t *m = (u_int32_t*)this->m;
+       u_int i;
+
+       r8  = _mm_set_epi8(14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3);
+       r16 = _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2);
+       ctrinc = _mm_set_epi32(3, 2, 1, 0);
+
+       x0 = _mm_set1_epi32(m[ 0]);
+       x1 = _mm_set1_epi32(m[ 1]);
+       x2 = _mm_set1_epi32(m[ 2]);
+       x3 = _mm_set1_epi32(m[ 3]);
+       x4 = _mm_set1_epi32(m[ 4]);
+       x5 = _mm_set1_epi32(m[ 5]);
+       x6 = _mm_set1_epi32(m[ 6]);
+       x7 = _mm_set1_epi32(m[ 7]);
+       x8 = _mm_set1_epi32(m[ 8]);
+       x9 = _mm_set1_epi32(m[ 9]);
+       xa = _mm_set1_epi32(m[10]);
+       xb = _mm_set1_epi32(m[11]);
+       xc = _mm_set1_epi32(m[12]);
+       xd = _mm_set1_epi32(m[13]);
+       xe = _mm_set1_epi32(m[14]);
+       xf = _mm_set1_epi32(m[15]);
+
+       xc = _mm_add_epi32(xc, ctrinc);
+
+       for (i = 0 ; i < CHACHA_DOUBLEROUNDS; i++)
+       {
+               x0 = _mm_add_epi32(x0, x4); xc = sfflxor32(xc, x0, r16);
+               x1 = _mm_add_epi32(x1, x5); xd = sfflxor32(xd, x1, r16);
+               x2 = _mm_add_epi32(x2, x6); xe = sfflxor32(xe, x2, r16);
+               x3 = _mm_add_epi32(x3, x7); xf = sfflxor32(xf, x3, r16);
+
+               x8 = _mm_add_epi32(x8, xc); x4 = rotlxor32(x4, x8, 12);
+               x9 = _mm_add_epi32(x9, xd); x5 = rotlxor32(x5, x9, 12);
+               xa = _mm_add_epi32(xa, xe); x6 = rotlxor32(x6, xa, 12);
+               xb = _mm_add_epi32(xb, xf); x7 = rotlxor32(x7, xb, 12);
+
+               x0 = _mm_add_epi32(x0, x4); xc = sfflxor32(xc, x0, r8);
+               x1 = _mm_add_epi32(x1, x5); xd = sfflxor32(xd, x1, r8);
+               x2 = _mm_add_epi32(x2, x6); xe = sfflxor32(xe, x2, r8);
+               x3 = _mm_add_epi32(x3, x7); xf = sfflxor32(xf, x3, r8);
+
+               x8 = _mm_add_epi32(x8, xc); x4 = rotlxor32(x4, x8, 7);
+               x9 = _mm_add_epi32(x9, xd); x5 = rotlxor32(x5, x9, 7);
+               xa = _mm_add_epi32(xa, xe); x6 = rotlxor32(x6, xa, 7);
+               xb = _mm_add_epi32(xb, xf); x7 = rotlxor32(x7, xb, 7);
+
+               x0 = _mm_add_epi32(x0, x5); xf = sfflxor32(xf, x0, r16);
+               x1 = _mm_add_epi32(x1, x6); xc = sfflxor32(xc, x1, r16);
+               x2 = _mm_add_epi32(x2, x7); xd = sfflxor32(xd, x2, r16);
+               x3 = _mm_add_epi32(x3, x4); xe = sfflxor32(xe, x3, r16);
+
+               xa = _mm_add_epi32(xa, xf); x5 = rotlxor32(x5, xa, 12);
+               xb = _mm_add_epi32(xb, xc); x6 = rotlxor32(x6, xb, 12);
+               x8 = _mm_add_epi32(x8, xd); x7 = rotlxor32(x7, x8, 12);
+               x9 = _mm_add_epi32(x9, xe); x4 = rotlxor32(x4, x9, 12);
+
+               x0 = _mm_add_epi32(x0, x5); xf = sfflxor32(xf, x0, r8);
+               x1 = _mm_add_epi32(x1, x6); xc = sfflxor32(xc, x1, r8);
+               x2 = _mm_add_epi32(x2, x7); xd = sfflxor32(xd, x2, r8);
+               x3 = _mm_add_epi32(x3, x4); xe = sfflxor32(xe, x3, r8);
+
+               xa = _mm_add_epi32(xa, xf); x5 = rotlxor32(x5, xa, 7);
+               xb = _mm_add_epi32(xb, xc); x6 = rotlxor32(x6, xb, 7);
+               x8 = _mm_add_epi32(x8, xd); x7 = rotlxor32(x7, x8, 7);
+               x9 = _mm_add_epi32(x9, xe); x4 = rotlxor32(x4, x9, 7);
+       }
+
+       x0 = _mm_add_epi32(x0, _mm_set1_epi32(m[ 0]));
+       x1 = _mm_add_epi32(x1, _mm_set1_epi32(m[ 1]));
+       x2 = _mm_add_epi32(x2, _mm_set1_epi32(m[ 2]));
+       x3 = _mm_add_epi32(x3, _mm_set1_epi32(m[ 3]));
+       x4 = _mm_add_epi32(x4, _mm_set1_epi32(m[ 4]));
+       x5 = _mm_add_epi32(x5, _mm_set1_epi32(m[ 5]));
+       x6 = _mm_add_epi32(x6, _mm_set1_epi32(m[ 6]));
+       x7 = _mm_add_epi32(x7, _mm_set1_epi32(m[ 7]));
+       x8 = _mm_add_epi32(x8, _mm_set1_epi32(m[ 8]));
+       x9 = _mm_add_epi32(x9, _mm_set1_epi32(m[ 9]));
+       xa = _mm_add_epi32(xa, _mm_set1_epi32(m[10]));
+       xb = _mm_add_epi32(xb, _mm_set1_epi32(m[11]));
+       xc = _mm_add_epi32(xc, _mm_set1_epi32(m[12]));
+       xd = _mm_add_epi32(xd, _mm_set1_epi32(m[13]));
+       xe = _mm_add_epi32(xe, _mm_set1_epi32(m[14]));
+       xf = _mm_add_epi32(xf, _mm_set1_epi32(m[15]));
+
+       xc = _mm_add_epi32(xc, ctrinc);
+
+       /* transpose state matrix by interleaving 32-, then 64-bit words */
+       t = x0; x0 = _mm_unpacklo_epi32(t, x1);
+                       x1 = _mm_unpackhi_epi32(t, x1);
+       t = x2; x2 = _mm_unpacklo_epi32(t, x3);
+                       x3 = _mm_unpackhi_epi32(t, x3);
+       t = x4; x4 = _mm_unpacklo_epi32(t, x5);
+                       x5 = _mm_unpackhi_epi32(t, x5);
+       t = x6; x6 = _mm_unpacklo_epi32(t, x7);
+                       x7 = _mm_unpackhi_epi32(t, x7);
+       t = x8; x8 = _mm_unpacklo_epi32(t, x9);
+                       x9 = _mm_unpackhi_epi32(t, x9);
+       t = xa; xa = _mm_unpacklo_epi32(t, xb);
+                       xb = _mm_unpackhi_epi32(t, xb);
+       t = xc; xc = _mm_unpacklo_epi32(t, xd);
+                       xd = _mm_unpackhi_epi32(t, xd);
+       t = xe; xe = _mm_unpacklo_epi32(t, xf);
+                       xf = _mm_unpackhi_epi32(t, xf);
+
+       t = x0; x0 = _mm_unpacklo_epi64(t, x2);
+                       x2 = _mm_unpackhi_epi64(t, x2);
+       t = x1; x1 = _mm_unpacklo_epi64(t, x3);
+                       x3 = _mm_unpackhi_epi64(t, x3);
+       t = x4; x4 = _mm_unpacklo_epi64(t, x6);
+                       x6 = _mm_unpackhi_epi64(t, x6);
+       t = x5; x5 = _mm_unpacklo_epi64(t, x7);
+                       x7 = _mm_unpackhi_epi64(t, x7);
+       t = x8; x8 = _mm_unpacklo_epi64(t, xa);
+                       xa = _mm_unpackhi_epi64(t, xa);
+       t = x9; x9 = _mm_unpacklo_epi64(t, xb);
+                       xb = _mm_unpackhi_epi64(t, xb);
+       t = xc; xc = _mm_unpacklo_epi64(t, xe);
+                       xe = _mm_unpackhi_epi64(t, xe);
+       t = xd; xd = _mm_unpacklo_epi64(t, xf);
+                       xf = _mm_unpackhi_epi64(t, xf);
+
+       x0 = _mm_xor_si128(_mm_loadu_si128(out +  0), x0);
+       x1 = _mm_xor_si128(_mm_loadu_si128(out +  8), x1);
+       x2 = _mm_xor_si128(_mm_loadu_si128(out +  4), x2);
+       x3 = _mm_xor_si128(_mm_loadu_si128(out + 12), x3);
+       x4 = _mm_xor_si128(_mm_loadu_si128(out +  1), x4);
+       x5 = _mm_xor_si128(_mm_loadu_si128(out +  9), x5);
+       x6 = _mm_xor_si128(_mm_loadu_si128(out +  5), x6);
+       x7 = _mm_xor_si128(_mm_loadu_si128(out + 13), x7);
+       x8 = _mm_xor_si128(_mm_loadu_si128(out +  2), x8);
+       x9 = _mm_xor_si128(_mm_loadu_si128(out + 10), x9);
+       xa = _mm_xor_si128(_mm_loadu_si128(out +  6), xa);
+       xb = _mm_xor_si128(_mm_loadu_si128(out + 14), xb);
+       xc = _mm_xor_si128(_mm_loadu_si128(out +  3), xc);
+       xd = _mm_xor_si128(_mm_loadu_si128(out + 11), xd);
+       xe = _mm_xor_si128(_mm_loadu_si128(out +  7), xe);
+       xf = _mm_xor_si128(_mm_loadu_si128(out + 15), xf);
+
+       _mm_storeu_si128(out +  0, x0);
+       _mm_storeu_si128(out +  8, x1);
+       _mm_storeu_si128(out +  4, x2);
+       _mm_storeu_si128(out + 12, x3);
+       _mm_storeu_si128(out +  1, x4);
+       _mm_storeu_si128(out +  9, x5);
+       _mm_storeu_si128(out +  5, x6);
+       _mm_storeu_si128(out + 13, x7);
+       _mm_storeu_si128(out +  2, x8);
+       _mm_storeu_si128(out + 10, x9);
+       _mm_storeu_si128(out +  6, xa);
+       _mm_storeu_si128(out + 14, xb);
+       _mm_storeu_si128(out +  3, xc);
+       _mm_storeu_si128(out + 11, xd);
+       _mm_storeu_si128(out +  7, xe);
+       _mm_storeu_si128(out + 15, xf);
+
+       this->m[3] = _mm_add_epi32(this->m[3], _mm_set_epi32(0, 0, 0, 4));
+}
+
+
 METHOD(chapoly_drv_t, set_key, bool,
        private_chapoly_drv_ssse3_t *this, u_char *constant, u_char *key,
        u_char *salt)
@@ -342,6 +519,13 @@ METHOD(chapoly_drv_t, encrypt, bool,
 {
        u_int i;
 
+       while (blocks >= 4)
+       {
+               chacha_4block_xor(this, data);
+               poly(this, data, 16);
+               data += CHACHA_BLOCK_SIZE * 4;
+               blocks -= 4;
+       }
        for (i = 0; i < blocks; i++)
        {
                chacha_block_xor(this, data);
@@ -356,6 +540,13 @@ METHOD(chapoly_drv_t, decrypt, bool,
 {
        u_int i;
 
+       while (blocks >= 4)
+       {
+               poly(this, data, 16);
+               chacha_4block_xor(this, data);
+               data += CHACHA_BLOCK_SIZE * 4;
+               blocks -= 4;
+       }
        for (i = 0; i < blocks; i++)
        {
                poly(this, data, 4);