added aes-ni support

This commit is contained in:
Logan007 2020-09-01 02:09:31 +05:45
parent c90c8e6d43
commit 81a1ccc702
2 changed files with 339 additions and 12 deletions

View File

@ -60,6 +60,16 @@ typedef struct aes_context_t {
AES_KEY dec_key; /* tx key */
} aes_context_t;
#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------
#include <immintrin.h>
typedef struct aes_context_t {
__m128i rk_enc[15];
__m128i rk_dec[15];
int Nr;
} aes_context_t;
#else // plain C --------------------------------------------------------------------------
typedef struct aes_context_t {
@ -68,7 +78,7 @@ typedef struct aes_context_t {
int Nr; // number of rounds
} aes_context_t;
#endif
#endif // ---------------------------------------------------------------------------------
int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,

339
src/aes.c
View File

@ -235,6 +235,284 @@ int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
}
#elif defined (__AES__) && defined (__SSE2__) // Intel's AES-NI ---------------------------
// inspired by https://gist.github.com/acapola/d5b940da024080dfaf5f
// furthered by the help of Sebastian Ramacher's implementation found at
// https://chromium.googlesource.com/external/github.com/dlitz/pycrypto/+/junk/master/src/AESNI.c
// modified along Intel's white paper on AES Instruction Set
// https://www.intel.com/content/dam/doc/white-paper/advanced-encryption-standard-new-instructions-set-paper.pdf
static __m128i aes128_keyexpand(__m128i key, __m128i keygened, uint8_t shuf) {
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
// unfortunately, shuffle expects immediate argument ... macrorize???!!!
switch (shuf) {
case 0x55:
keygened = _mm_shuffle_epi32(keygened, 0x55 );
break;
case 0xaa:
keygened = _mm_shuffle_epi32(keygened, 0xaa );
break;
case 0xff:
keygened = _mm_shuffle_epi32(keygened, 0xff );
break;
default:
break;
}
return _mm_xor_si128(key, keygened);
}
static __m128i aes192_keyexpand_2(__m128i key, __m128i key2)
{
key = _mm_shuffle_epi32(key, 0xff);
key2 = _mm_xor_si128(key2, _mm_slli_si128(key2, 4));
return _mm_xor_si128(key, key2);
}
#define KEYEXP128(K, I) aes128_keyexpand(K, _mm_aeskeygenassist_si128(K, I), 0xff)
#define KEYEXP192(K1, K2, I) aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, I), 0x55)
#define KEYEXP192_2(K1, K2) aes192_keyexpand_2(K1, K2)
#define KEYEXP256(K1, K2, I) aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, I), 0xff)
#define KEYEXP256_2(K1, K2) aes128_keyexpand(K1, _mm_aeskeygenassist_si128(K2, 0x00), 0xaa)
// key setup
static int aes_internal_key_setup (aes_context_t *ctx, const uint8_t *key, int key_bits) {
// number of rounds
ctx->Nr = 6 + (key_bits / 32);
// encryption keys
switch (key_bits) {
case 128: {
ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*)key);
ctx->rk_enc[1] = KEYEXP128(ctx->rk_enc[0], 0x01);
ctx->rk_enc[2] = KEYEXP128(ctx->rk_enc[1], 0x02);
ctx->rk_enc[3] = KEYEXP128(ctx->rk_enc[2], 0x04);
ctx->rk_enc[4] = KEYEXP128(ctx->rk_enc[3], 0x08);
ctx->rk_enc[5] = KEYEXP128(ctx->rk_enc[4], 0x10);
ctx->rk_enc[6] = KEYEXP128(ctx->rk_enc[5], 0x20);
ctx->rk_enc[7] = KEYEXP128(ctx->rk_enc[6], 0x40);
ctx->rk_enc[8] = KEYEXP128(ctx->rk_enc[7], 0x80);
ctx->rk_enc[9] = KEYEXP128(ctx->rk_enc[8], 0x1B);
ctx->rk_enc[10] = KEYEXP128(ctx->rk_enc[9], 0x36);
break;
}
case 192: {
__m128i temp[2];
ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*) key);
ctx->rk_enc[1] = _mm_loadu_si128((const __m128i*) (key+16));
temp[0] = KEYEXP192(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[1]);
ctx->rk_enc[1] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[1], (__m128d)temp[0], 0);
ctx->rk_enc[2] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
ctx->rk_enc[3] = KEYEXP192(temp[0], temp[1], 0x02);
ctx->rk_enc[4] = KEYEXP192_2(ctx->rk_enc[3], temp[1]);
temp[0] = KEYEXP192(ctx->rk_enc[3], ctx->rk_enc[4], 0x04);
temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[4]);
ctx->rk_enc[4] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[4], (__m128d)temp[0], 0);
ctx->rk_enc[5] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
ctx->rk_enc[6] = KEYEXP192(temp[0], temp[1], 0x08);
ctx->rk_enc[7] = KEYEXP192_2(ctx->rk_enc[6], temp[1]);
temp[0] = KEYEXP192(ctx->rk_enc[6], ctx->rk_enc[7], 0x10);
temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[7]);
ctx->rk_enc[7] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[7], (__m128d)temp[0], 0);
ctx->rk_enc[8] = (__m128i)_mm_shuffle_pd((__m128d)temp[0], (__m128d)temp[1], 1);
ctx->rk_enc[9] = KEYEXP192(temp[0], temp[1], 0x20);
ctx->rk_enc[10] = KEYEXP192_2(ctx->rk_enc[9], temp[1]);
temp[0] = KEYEXP192(ctx->rk_enc[9], ctx->rk_enc[10], 0x40);
temp[1] = KEYEXP192_2(temp[0], ctx->rk_enc[10]);
ctx->rk_enc[10] = (__m128i)_mm_shuffle_pd((__m128d)ctx->rk_enc[10], (__m128d) temp[0], 0);
ctx->rk_enc[11] = (__m128i)_mm_shuffle_pd((__m128d)temp[0],(__m128d) temp[1], 1);
ctx->rk_enc[12] = KEYEXP192(temp[0], temp[1], 0x80);
break;
}
case 256: {
ctx->rk_enc[0] = _mm_loadu_si128((const __m128i*) key);
ctx->rk_enc[1] = _mm_loadu_si128((const __m128i*) (key+16));
ctx->rk_enc[2] = KEYEXP256(ctx->rk_enc[0], ctx->rk_enc[1], 0x01);
ctx->rk_enc[3] = KEYEXP256_2(ctx->rk_enc[1], ctx->rk_enc[2]);
ctx->rk_enc[4] = KEYEXP256(ctx->rk_enc[2], ctx->rk_enc[3], 0x02);
ctx->rk_enc[5] = KEYEXP256_2(ctx->rk_enc[3], ctx->rk_enc[4]);
ctx->rk_enc[6] = KEYEXP256(ctx->rk_enc[4], ctx->rk_enc[5], 0x04);
ctx->rk_enc[7] = KEYEXP256_2(ctx->rk_enc[5], ctx->rk_enc[6]);
ctx->rk_enc[8] = KEYEXP256(ctx->rk_enc[6], ctx->rk_enc[7], 0x08);
ctx->rk_enc[9] = KEYEXP256_2(ctx->rk_enc[7], ctx->rk_enc[8]);
ctx->rk_enc[10] = KEYEXP256(ctx->rk_enc[8], ctx->rk_enc[9], 0x10);
ctx->rk_enc[11] = KEYEXP256_2(ctx->rk_enc[9], ctx->rk_enc[10]);
ctx->rk_enc[12] = KEYEXP256(ctx->rk_enc[10], ctx->rk_enc[11], 0x20);
ctx->rk_enc[13] = KEYEXP256_2(ctx->rk_enc[11], ctx->rk_enc[12]);
ctx->rk_enc[14] = KEYEXP256(ctx->rk_enc[12], ctx->rk_enc[13], 0x40);
break;
}
}
// derive decryption keys
for (int i = 1; i < ctx->Nr; ++i) {
ctx->rk_dec[ctx->Nr - i] = _mm_aesimc_si128(ctx->rk_enc[i]);
}
ctx->rk_dec[0] = ctx->rk_enc[ctx->Nr];
return ctx->Nr;
}
static void aes_internal_encrypt (aes_context_t *ctx, const uint8_t pt[16], uint8_t ct[16]) {
__m128i tmp = _mm_loadu_si128((__m128i*)pt);
tmp = _mm_xor_si128 (tmp, ctx->rk_enc[ 0]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 1]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 2]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 3]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 4]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 5]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 6]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 7]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 8]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[ 9]);
if(ctx->Nr > 10) {
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[10]);
tmp = _mm_aesenc_si128 (tmp, ctx->rk_enc[11]);
if(ctx->Nr > 12) {
tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[12]);
tmp = _mm_aesenc_si128(tmp, ctx->rk_enc[13]);
}
}
tmp = _mm_aesenclast_si128(tmp, ctx->rk_enc[ctx->Nr]);
_mm_storeu_si128((__m128i*) ct, tmp);
}
static void aes_internal_decrypt (aes_context_t *ctx, const uint8_t ct[16], uint8_t pt[16]) {
__m128i tmp = _mm_loadu_si128((__m128i*)ct);
tmp = _mm_xor_si128 (tmp, ctx->rk_dec[ 0]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 1]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 2]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 3]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 4]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 5]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 6]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 7]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 8]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[ 9]);
if(ctx->Nr > 10) {
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[10]);
tmp = _mm_aesdec_si128 (tmp, ctx->rk_dec[11]);
if(ctx->Nr > 12) {
tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[12]);
tmp = _mm_aesdec_si128(tmp, ctx->rk_dec[13]);
}
}
tmp = _mm_aesdeclast_si128(tmp, ctx->rk_enc[0]);
_mm_storeu_si128((__m128i*) pt, tmp);
}
// public API
int aes_ecb_decrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
aes_internal_decrypt(ctx, in, out);
return AES_BLOCK_SIZE;
}
// not used
int aes_ecb_encrypt (unsigned char *out, const unsigned char *in, aes_context_t *ctx) {
aes_internal_encrypt(ctx, in, out);
return AES_BLOCK_SIZE;
}
#define fix_xor(target, source) *(uint32_t*)&(target)[0] = *(uint32_t*)&(target)[0] ^ *(uint32_t*)&(source)[0]; *(uint32_t*)&(target)[4] = *(uint32_t*)&(target)[4] ^ *(uint32_t*)&(source)[4]; \
*(uint32_t*)&(target)[8] = *(uint32_t*)&(target)[8] ^ *(uint32_t*)&(source)[8]; *(uint32_t*)&(target)[12] = *(uint32_t*)&(target)[12] ^ *(uint32_t*)&(source)[12];
int aes_cbc_encrypt (unsigned char *out, const unsigned char *in, size_t in_len,
const unsigned char *iv, aes_context_t *ctx) {
uint8_t tmp[AES_BLOCK_SIZE];
size_t i;
size_t n;
memcpy(tmp, iv, AES_BLOCK_SIZE);
n = in_len / AES_BLOCK_SIZE;
for(i=0; i < n; i++) {
fix_xor(tmp, &in[i * AES_BLOCK_SIZE]);
aes_internal_encrypt(ctx, tmp, tmp);
memcpy(&out[i * AES_BLOCK_SIZE], tmp, AES_BLOCK_SIZE);
}
return n * AES_BLOCK_SIZE;
}
int aes_cbc_decrypt (unsigned char *out, const unsigned char *in, size_t in_len,
const unsigned char *iv, aes_context_t *ctx) {
uint8_t tmp[AES_BLOCK_SIZE];
uint8_t old[AES_BLOCK_SIZE];
size_t i;
size_t n;
memcpy(tmp, iv, AES_BLOCK_SIZE);
n = in_len / AES_BLOCK_SIZE;
for(i=0; i < n; i++) {
memcpy(old, &in[i * AES_BLOCK_SIZE], AES_BLOCK_SIZE);
aes_internal_decrypt(ctx, &in[i * AES_BLOCK_SIZE], &out[i * AES_BLOCK_SIZE]);
fix_xor(&out[i * AES_BLOCK_SIZE], tmp);
memcpy(tmp, old, AES_BLOCK_SIZE);
}
return n * AES_BLOCK_SIZE;
}
int aes_init (const unsigned char *key, size_t key_size, aes_context_t **ctx) {
// allocate context...
*ctx = (aes_context_t*) calloc(1, sizeof(aes_context_t));
if (!(*ctx))
return -1;
// ...and fill her up
// initialize data structures
// check key size and make key size (given in bytes) dependant settings
switch(key_size) {
case AES128_KEY_BYTES: // 128 bit key size
break;
case AES192_KEY_BYTES: // 192 bit key size
break;
case AES256_KEY_BYTES: // 256 bit key size
break;
default:
traceEvent(TRACE_ERROR, "aes_init invalid key size %u\n", key_size);
return -1;
}
// key materiel handling
aes_internal_key_setup ( *ctx, key, 8 * key_size);
return 0;
}
#else // plain C --------------------------------------------------------------------------
// rijndael-alg-fst.c version 3.0 (December 2000)
@ -624,6 +902,7 @@ static const uint32_t rcon[] = {
* @return the number of rounds for the given cipher key size.
*/
static int aes_internal_key_setup_enc (uint32_t rk[/*4*(Nr + 1)*/], const uint8_t cipherKey[], int keyBits) {
int i = 0;
uint32_t temp;
@ -998,31 +1277,69 @@ int aes_deinit (aes_context_t *ctx) {
#endif
return 0;
} */
}
#ifdef TEST_AES
int main () {
uint32_t rk[60];
uint8_t key[32] = {0};
aes_context_t *ctx;
// *ctx = malloc(sizeof(aes_context_t));
// uint8_t key[32] = {0};
// 128 bit key 0 --> 0336763e966d92595a567cc9ce537f5e
// uint8_t pt[16] = {0xf3, 0x44, 0x81, 0xec, 0x3c, 0xc6, 0x27, 0xba,
// 0xcd, 0x5d, 0xc3, 0xfb, 0x08, 0xf2, 0x73, 0xe6 };
uint8_t pt[16] = {0x01, 0x47, 0x30, 0xf8, 0x0a, 0xc6, 0x25, 0xfe,
0x84, 0xf0, 0x26, 0xc6, 0x0b, 0xfd, 0x54, 0x7d };
// 256 bit key 0 --> 5c9d844ed46f9885085e5d6a4f94c7d7
// uint8_t pt[16] = {0x01, 0x47, 0x30, 0xf8, 0x0a, 0xc6, 0x25, 0xfe,
// 0x84, 0xf0, 0x26, 0xc6, 0x0b, 0xfd, 0x54, 0x7d };
uint8_t pt[16] = {0};
// 0 pt --> 6d251e6944b051e04eaa6fb4dbf78465
uint8_t key[16] = {0x10, 0xa5, 0x88, 0x69, 0xd7, 0x4b, 0xe5, 0xa3,
0x74, 0xcf, 0x86, 0x7c, 0xfb, 0x47, 0x38, 0x59 };
uint8_t ct[16] = {0};
int i;
i = aes_internal_key_setup_enc(rk/*[4*(Nr + 1)]*/, key, 8 * sizeof(key));
printf ("i = %u\n",i);
aes_internal_encrypt(rk, i, pt, ct);
i = aes_internal_key_setup_dec(rk/*[4*(Nr + 1)]*/, key, 8 * sizeof(key));
// aes_internal_key_setup (ctx, key, 8 * sizeof(key));
aes_init (key, sizeof(key), &ctx);
printf ("Nr = %u\n",(ctx)->Nr);
memset (pt, 0, 16);
aes_internal_decrypt(rk, i, ct, pt);
for(i = 0; i < 16; i++)
printf ("%02x",pt[i]);
printf ("\n");
printf ("--- pt\n");
aes_internal_encrypt((ctx), pt, ct);
memset (pt, 4, 16);
for(i = 0; i < 16; i++)
printf ("%02x",ct[i]);
printf ("--- ct\n");
printf ("Nr = %u\n",(ctx)->Nr);
printf ("Nr = %u\n",(ctx)->Nr);
aes_internal_decrypt((ctx), ct, pt);
memset (ct, 9, 16);
for(i = 0; i < 16; i++)
printf ("%02x",pt[i]);
printf ("--- pt\n");
aes_internal_encrypt((ctx), pt, ct);
for(i = 0; i < 16; i++)
printf ("%02x",ct[i]);
printf ("--- ct\n");
}
#endif
*/