chapoly: Process two Poly1305 blocks in parallel in SSSE3 driver
authorMartin Willi <martin@revosec.ch>
Tue, 7 Apr 2015 09:28:51 +0000 (11:28 +0200)
committerMartin Willi <martin@revosec.ch>
Sun, 12 Jul 2015 11:25:50 +0000 (13:25 +0200)
By using a derived key r^2 we can improve performance, as we can do loop
unrolling and slightly better utilize SIMD instructions.

Overall ChaCha20-Poly1305 performance increases by ~12%.

Converting integers to/from our 5-word representation in SSE does not seem
to pay off, so we work on individual words.

src/libstrongswan/plugins/chapoly/chapoly_drv_ssse3.c

index a33bbf1..df88e7d 100644 (file)
@@ -48,6 +48,11 @@ struct private_chapoly_drv_ssse3_t {
        u_int32_t r[5];
 
        /**
+        * Poly1305 update key r^2
+        */
+       u_int32_t u[5];
+
+       /**
         * Poly1305 state
         */
        u_int32_t h[5];
@@ -344,7 +349,6 @@ static void chacha_4block_xor(private_chapoly_drv_ssse3_t *this, void *data)
        this->m[3] = _mm_add_epi32(this->m[3], _mm_set_epi32(0, 0, 0, 4));
 }
 
-
 METHOD(chapoly_drv_t, set_key, bool,
        private_chapoly_drv_ssse3_t *this, u_char *constant, u_char *key,
        u_char *salt)
@@ -357,6 +361,121 @@ METHOD(chapoly_drv_t, set_key, bool,
        return TRUE;
 }
 
+/**
+ * r[127:64] = h[95:64] * a, r[63:0] = h[31:0] * b
+ */
+static inline __m128i mul2(__m128i h, u_int32_t a, u_int32_t b)
+{
+       return _mm_mul_epu32(h, _mm_set_epi32(0, a, 0, b));
+}
+
+/**
+ * c = a[127:64] + a[63:0] + b[127:64] + b[63:0]
+ * z = x[127:64] + x[63:0] + y[127:64] + y[63:0]
+ */
+static inline void sum2(__m128i a, __m128i b, __m128i x, __m128i y,
+                                               u_int64_t *c, u_int64_t *z)
+{
+       __m128i r, s;
+
+       a = _mm_add_epi64(a, b);
+       x = _mm_add_epi64(x, y);
+       r = _mm_unpacklo_epi64(x, a);
+       s = _mm_unpackhi_epi64(x, a);
+       r = _mm_add_epi64(r, s);
+
+       _mm_storel_epi64((__m128i*)z, r);
+       _mm_storel_epi64((__m128i*)c, _mm_srli_si128(r, 8));
+}
+
+/**
+ * r = a[127:64] + b[127:64] + c[127:64] + d[127:64] + e[127:64]
+ *   + a[63:0]   + b[63:0]   + c[63:0]   + d[63:0]   + e[63:0]
+ */
+static inline u_int64_t sum5(__m128i a, __m128i b, __m128i c,
+                                                        __m128i d, __m128i e)
+{
+       u_int64_t r;
+
+       a = _mm_add_epi64(a, b);
+       c = _mm_add_epi64(c, d);
+       a = _mm_add_epi64(a, e);
+       a = _mm_add_epi64(a, c);
+
+       a = _mm_add_epi64(a, _mm_srli_si128(a, 8));
+       _mm_storel_epi64((__m128i*)&r, a);
+
+       return r;
+}
+
+/**
+ * Make second Poly1305 key u = r^2
+ */
+static void make_u(private_chapoly_drv_ssse3_t *this)
+{
+       __m128i r01, r23, r44, x0, x1, y0, y1, z0;
+       u_int32_t r0, r1, r2, r3, r4;
+       u_int32_t u0, u1, u2, u3, u4;
+       u_int32_t s1, s2, s3, s4;
+       u_int64_t d0, d1, d2, d3, d4;
+
+       r0 = this->r[0];
+       r1 = this->r[1];
+       r2 = this->r[2];
+       r3 = this->r[3];
+       r4 = this->r[4];
+
+       s1 = r1 * 5;
+       s2 = r2 * 5;
+       s3 = r3 * 5;
+       s4 = r4 * 5;
+
+       r01 = _mm_set_epi32(0, r0, 0, r1);
+       r23 = _mm_set_epi32(0, r2, 0, r3);
+       r44 = _mm_set_epi32(0, r4, 0, r4);
+
+       /* u = r^2 */
+       x0 = mul2(r01, r0, s4);
+       x1 = mul2(r01, r1, r0);
+       y0 = mul2(r23, s3, s2);
+       y1 = mul2(r23, s4, s3);
+       z0 = mul2(r44, s1, s2);
+       y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
+       y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
+       sum2(x0, y0, x1, y1, &d0, &d1);
+
+       x0 = mul2(r01, r2, r1);
+       x1 = mul2(r01, r3, r2);
+       y0 = mul2(r23, r0, s4);
+       y1 = mul2(r23, r1, r0);
+       z0 = mul2(r44, s3, s4);
+       y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
+       y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
+       sum2(x0, y0, x1, y1, &d2, &d3);
+
+       x0 = mul2(r01, r4, r3);
+       y0 = mul2(r23, r2, r1);
+       z0 = mul2(r44, r0, 0);
+       y0 = _mm_add_epi64(y0, z0);
+       x0 = _mm_add_epi64(x0, y0);
+       x0 = _mm_add_epi64(x0, _mm_srli_si128(x0, 8));
+       _mm_storel_epi64((__m128i*)&d4, x0);
+
+       /* (partial) r %= p */
+       d1 += sr(d0, 26);     u0 = and(d0, 0x3ffffff);
+       d2 += sr(d1, 26);     u1 = and(d1, 0x3ffffff);
+       d3 += sr(d2, 26);     u2 = and(d2, 0x3ffffff);
+       d4 += sr(d3, 26);     u3 = and(d3, 0x3ffffff);
+       u0 += sr(d4, 26) * 5; u4 = and(d4, 0x3ffffff);
+       u1 += u0 >> 26;       u0 &= 0x3ffffff;
+
+       this->u[0] = u0;
+       this->u[1] = u1;
+       this->u[2] = u2;
+       this->u[3] = u3;
+       this->u[4] = u4;
+}
+
 METHOD(chapoly_drv_t, init, bool,
        private_chapoly_drv_ssse3_t *this, u_char *iv)
 {
@@ -376,6 +495,8 @@ METHOD(chapoly_drv_t, init, bool,
        this->r[3] = (ru32(key +  9) >> 6) & 0x3f03fff;
        this->r[4] = (ru32(key + 12) >> 8) & 0x00fffff;
 
+       make_u(this);
+
        /* h = 0 */
        memwipe(this->h, sizeof(this->h));
 
@@ -388,42 +509,17 @@ METHOD(chapoly_drv_t, init, bool,
 }
 
 /**
- * r[127:64] = h[95:64] * a, r[63:0] = h[31:0] * b
+ * Update Poly1305 for a multiple of two blocks
  */
-static inline __m128i mul2(__m128i h, u_int32_t a, u_int32_t b)
-{
-       return _mm_mul_epu32(h, _mm_set_epi32(0, a, 0, b));
-}
-
-/**
- * c = a[127:64] + a[63:0] + b[127:64] + b[63:0]
- * z = x[127:64] + x[63:0] + y[127:64] + y[63:0]
- */
-static inline void sum2(__m128i a, __m128i b, __m128i x, __m128i y,
-                                               u_int64_t *c, u_int64_t *z)
-{
-       __m128i r, s;
-
-       a = _mm_add_epi64(a, b);
-       x = _mm_add_epi64(x, y);
-       r = _mm_unpacklo_epi64(x, a);
-       s = _mm_unpackhi_epi64(x, a);
-       r = _mm_add_epi64(r, s);
-
-       _mm_storel_epi64((__m128i*)z, r);
-       _mm_storel_epi64((__m128i*)c, _mm_srli_si128(r, 8));
-}
-
-METHOD(chapoly_drv_t, poly, bool,
-       private_chapoly_drv_ssse3_t *this, u_char *data, u_int blocks)
+static void poly2(private_chapoly_drv_ssse3_t *this, u_char *data, u_int dblks)
 {
-       u_int32_t r0, r1, r2, r3, r4;
-       u_int32_t s1, s2, s3, s4;
+       u_int32_t r0, r1, r2, r3, r4, u0, u1, u2, u3, u4;
+       u_int32_t s1, s2, s3, s4, v1, v2, v3, v4;
+       __m128i hc0, hc1, hc2, hc3, hc4;
        u_int32_t h0, h1, h2, h3, h4;
+       u_int32_t c0, c1, c2, c3, c4;
        u_int64_t d0, d1, d2, d3, d4;
-       __m128i h01, h23, h44;
-       __m128i x0, x1, y0, y1, z0;
-       u_int32_t t0, t1;
+       u_int i;
 
        r0 = this->r[0];
        r1 = this->r[1];
@@ -436,54 +532,74 @@ METHOD(chapoly_drv_t, poly, bool,
        s3 = r3 * 5;
        s4 = r4 * 5;
 
+       u0 = this->u[0];
+       u1 = this->u[1];
+       u2 = this->u[2];
+       u3 = this->u[3];
+       u4 = this->u[4];
+
+       v1 = u1 * 5;
+       v2 = u2 * 5;
+       v3 = u3 * 5;
+       v4 = u4 * 5;
+
        h0 = this->h[0];
        h1 = this->h[1];
        h2 = this->h[2];
        h3 = this->h[3];
        h4 = this->h[4];
 
-       while (blocks--)
+       /* h = (h + c1) * r^2 + c2 * r */
+       for (i = 0; i < dblks; i++)
        {
-               h01 = _mm_set_epi32(0, h0, 0, h1);
-               h23 = _mm_set_epi32(0, h2, 0, h3);
-               h44 = _mm_set_epi32(0, h4, 0, h4);
-
                /* h += m[i] */
-               t0  = (ru32(data +  0) >> 0) & 0x3ffffff;
-               t1  = (ru32(data +  3) >> 2) & 0x3ffffff;
-               h01 = _mm_add_epi32(h01, _mm_set_epi32(0, t0, 0, t1));
-               t0  = (ru32(data +  6) >> 4) & 0x3ffffff;
-               t1  = (ru32(data +  9) >> 6) & 0x3ffffff;
-               h23 = _mm_add_epi32(h23, _mm_set_epi32(0, t0, 0, t1));
-               t0  = (ru32(data + 12) >> 8) | (1 << 24);
-               h44 = _mm_add_epi32(h44, _mm_set_epi32(0, t0, 0, t0));
-
-               /* h *= r */
-               x0 = mul2(h01, r0, s4);
-               x1 = mul2(h01, r1, r0);
-               y0 = mul2(h23, s3, s2);
-               y1 = mul2(h23, s4, s3);
-               z0 = mul2(h44, s1, s2);
-               y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
-               y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
-               sum2(x0, y0, x1, y1, &d0, &d1);
-
-               x0 = mul2(h01, r2, r1);
-               x1 = mul2(h01, r3, r2);
-               y0 = mul2(h23, r0, s4);
-               y1 = mul2(h23, r1, r0);
-               z0 = mul2(h44, s3, s4);
-               y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
-               y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
-               sum2(x0, y0, x1, y1, &d2, &d3);
-
-               x0 = mul2(h01, r4, r3);
-               y0 = mul2(h23, r2, r1);
-               z0 = mul2(h44, r0, 0);
-               y0 = _mm_add_epi64(y0, z0);
-               x0 = _mm_add_epi64(x0, y0);
-               x0 = _mm_add_epi64(x0, _mm_srli_si128(x0, 8));
-               _mm_storel_epi64((__m128i*)&d4, x0);
+               h0 += (ru32(data +  0) >> 0) & 0x3ffffff;
+               h1 += (ru32(data +  3) >> 2) & 0x3ffffff;
+               h2 += (ru32(data +  6) >> 4) & 0x3ffffff;
+               h3 += (ru32(data +  9) >> 6) & 0x3ffffff;
+               h4 += (ru32(data + 12) >> 8) | (1 << 24);
+               data += POLY_BLOCK_SIZE;
+
+               /* c = m[i + 1] */
+               c0 = (ru32(data +  0) >> 0) & 0x3ffffff;
+               c1 = (ru32(data +  3) >> 2) & 0x3ffffff;
+               c2 = (ru32(data +  6) >> 4) & 0x3ffffff;
+               c3 = (ru32(data +  9) >> 6) & 0x3ffffff;
+               c4 = (ru32(data + 12) >> 8) | (1 << 24);
+               data += POLY_BLOCK_SIZE;
+
+               hc0 = _mm_set_epi32(0, h0, 0, c0);
+               hc1 = _mm_set_epi32(0, h1, 0, c1);
+               hc2 = _mm_set_epi32(0, h2, 0, c2);
+               hc3 = _mm_set_epi32(0, h3, 0, c3);
+               hc4 = _mm_set_epi32(0, h4, 0, c4);
+
+               /* h = h * r^2 + c * r */
+               d0 = sum5(mul2(hc0, u0, r0),
+                                 mul2(hc1, v4, s4),
+                                 mul2(hc2, v3, s3),
+                                 mul2(hc3, v2, s2),
+                                 mul2(hc4, v1, s1));
+               d1 = sum5(mul2(hc0, u1, r1),
+                                 mul2(hc1, u0, r0),
+                                 mul2(hc2, v4, s4),
+                                 mul2(hc3, v3, s3),
+                                 mul2(hc4, v2, s2));
+               d2 = sum5(mul2(hc0, u2, r2),
+                                 mul2(hc1, u1, r1),
+                                 mul2(hc2, u0, r0),
+                                 mul2(hc3, v4, s4),
+                                 mul2(hc4, v3, s3));
+               d3 = sum5(mul2(hc0, u3, r3),
+                                 mul2(hc1, u2, r2),
+                                 mul2(hc2, u1, r1),
+                                 mul2(hc3, u0, r0),
+                                 mul2(hc4, v4, s4));
+               d4 = sum5(mul2(hc0, u4, r4),
+                                 mul2(hc1, u3, r3),
+                                 mul2(hc2, u2, r2),
+                                 mul2(hc3, u1, r1),
+                                 mul2(hc4, u0, r0));
 
                /* (partial) h %= p */
                d1 += sr(d0, 26);     h0 = and(d0, 0x3ffffff);
@@ -492,8 +608,6 @@ METHOD(chapoly_drv_t, poly, bool,
                d4 += sr(d3, 26);     h3 = and(d3, 0x3ffffff);
                h0 += sr(d4, 26) * 5; h4 = and(d4, 0x3ffffff);
                h1 += h0 >> 26;       h0 = h0 & 0x3ffffff;
-
-               data += POLY_BLOCK_SIZE;
        }
 
        this->h[0] = h0;
@@ -501,7 +615,102 @@ METHOD(chapoly_drv_t, poly, bool,
        this->h[2] = h2;
        this->h[3] = h3;
        this->h[4] = h4;
+}
+
+/**
+ * Update Poly1305 for a single block
+ */
+static void poly1(private_chapoly_drv_ssse3_t *this, u_char *data)
+{
+       u_int32_t r0, r1, r2, r3, r4;
+       u_int32_t s1, s2, s3, s4;
+       u_int32_t h0, h1, h2, h3, h4;
+       u_int64_t d0, d1, d2, d3, d4;
+       __m128i h01, h23, h44;
+       __m128i x0, x1, y0, y1, z0;
+       u_int32_t t0, t1;
+
+       r0 = this->r[0];
+       r1 = this->r[1];
+       r2 = this->r[2];
+       r3 = this->r[3];
+       r4 = this->r[4];
+
+       s1 = r1 * 5;
+       s2 = r2 * 5;
+       s3 = r3 * 5;
+       s4 = r4 * 5;
+
+       h0 = this->h[0];
+       h1 = this->h[1];
+       h2 = this->h[2];
+       h3 = this->h[3];
+       h4 = this->h[4];
+
+       h01 = _mm_set_epi32(0, h0, 0, h1);
+       h23 = _mm_set_epi32(0, h2, 0, h3);
+       h44 = _mm_set_epi32(0, h4, 0, h4);
+
+       /* h += m[i] */
+       t0  = (ru32(data +  0) >> 0) & 0x3ffffff;
+       t1  = (ru32(data +  3) >> 2) & 0x3ffffff;
+       h01 = _mm_add_epi32(h01, _mm_set_epi32(0, t0, 0, t1));
+       t0  = (ru32(data +  6) >> 4) & 0x3ffffff;
+       t1  = (ru32(data +  9) >> 6) & 0x3ffffff;
+       h23 = _mm_add_epi32(h23, _mm_set_epi32(0, t0, 0, t1));
+       t0  = (ru32(data + 12) >> 8) | (1 << 24);
+       h44 = _mm_add_epi32(h44, _mm_set_epi32(0, t0, 0, t0));
+
+       /* h *= r */
+       x0 = mul2(h01, r0, s4);
+       x1 = mul2(h01, r1, r0);
+       y0 = mul2(h23, s3, s2);
+       y1 = mul2(h23, s4, s3);
+       z0 = mul2(h44, s1, s2);
+       y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
+       y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
+       sum2(x0, y0, x1, y1, &d0, &d1);
+
+       x0 = mul2(h01, r2, r1);
+       x1 = mul2(h01, r3, r2);
+       y0 = mul2(h23, r0, s4);
+       y1 = mul2(h23, r1, r0);
+       z0 = mul2(h44, s3, s4);
+       y0 = _mm_add_epi64(y0, _mm_srli_si128(z0, 8));
+       y1 = _mm_add_epi64(y1, _mm_slli_si128(z0, 8));
+       sum2(x0, y0, x1, y1, &d2, &d3);
+
+       x0 = mul2(h01, r4, r3);
+       y0 = mul2(h23, r2, r1);
+       z0 = mul2(h44, r0, 0);
+       y0 = _mm_add_epi64(y0, z0);
+       x0 = _mm_add_epi64(x0, y0);
+       x0 = _mm_add_epi64(x0, _mm_srli_si128(x0, 8));
+       _mm_storel_epi64((__m128i*)&d4, x0);
+
+       /* (partial) h %= p */
+       d1 += sr(d0, 26);     h0 = and(d0, 0x3ffffff);
+       d2 += sr(d1, 26);     h1 = and(d1, 0x3ffffff);
+       d3 += sr(d2, 26);     h2 = and(d2, 0x3ffffff);
+       d4 += sr(d3, 26);     h3 = and(d3, 0x3ffffff);
+       h0 += sr(d4, 26) * 5; h4 = and(d4, 0x3ffffff);
+       h1 += h0 >> 26;       h0 = h0 & 0x3ffffff;
+
+       this->h[0] = h0;
+       this->h[1] = h1;
+       this->h[2] = h2;
+       this->h[3] = h3;
+       this->h[4] = h4;
+}
 
+METHOD(chapoly_drv_t, poly, bool,
+       private_chapoly_drv_ssse3_t *this, u_char *data, u_int blocks)
+{
+       poly2(this, data, blocks / 2);
+       if (blocks-- % 2)
+       {
+               poly1(this, data + POLY_BLOCK_SIZE * blocks);
+       }
        return TRUE;
 }
 
@@ -517,19 +726,17 @@ METHOD(chapoly_drv_t, chacha, bool,
 METHOD(chapoly_drv_t, encrypt, bool,
        private_chapoly_drv_ssse3_t *this, u_char *data, u_int blocks)
 {
-       u_int i;
-
        while (blocks >= 4)
        {
                chacha_4block_xor(this, data);
-               poly(this, data, 16);
+               poly2(this, data, 8);
                data += CHACHA_BLOCK_SIZE * 4;
                blocks -= 4;
        }
-       for (i = 0; i < blocks; i++)
+       while (blocks--)
        {
                chacha_block_xor(this, data);
-               poly(this, data, 4);
+               poly2(this, data, 2);
                data += CHACHA_BLOCK_SIZE;
        }
        return TRUE;
@@ -538,18 +745,16 @@ METHOD(chapoly_drv_t, encrypt, bool,
 METHOD(chapoly_drv_t, decrypt, bool,
        private_chapoly_drv_ssse3_t *this, u_char *data, u_int blocks)
 {
-       u_int i;
-
        while (blocks >= 4)
        {
-               poly(this, data, 16);
+               poly2(this, data, 8);
                chacha_4block_xor(this, data);
                data += CHACHA_BLOCK_SIZE * 4;
                blocks -= 4;
        }
-       for (i = 0; i < blocks; i++)
+       while (blocks--)
        {
-               poly(this, data, 4);
+               poly2(this, data, 2);
                chacha_block_xor(this, data);
                data += CHACHA_BLOCK_SIZE;
        }
@@ -619,6 +824,7 @@ METHOD(chapoly_drv_t, destroy, void,
        memwipe(this->m, sizeof(this->m));
        memwipe(this->h, sizeof(this->h));
        memwipe(this->r, sizeof(this->r));
+       memwipe(this->u, sizeof(this->u));
        memwipe(this->s, sizeof(this->s));
        free_align(this);
 }