C++: Can this be optimized?

16 comments, last by cache_hit 14 years, 7 months ago
Greetings, I'm working on an AES implementation (http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf). It works fine, and I'm using it as part of some other signal processing. Now I'm trying to optimize its most cycle-intensive parts. It boils down to the polynomial multiplication, which is eating up over half of all the cycles in my entire signal processing chain. Here's the code snippet:
uint8_t Aes::Mult(uint8_t inputA, uint8_t inputB)
{
    uint16_t temp = 0;

    // Polynomial multiplication.
    for (uint16_t bitCnt = 0; bitCnt < 8; bitCnt++)
    {
        if ((inputA >> bitCnt) & 0x01)
        {
            temp ^= (inputB << bitCnt);
        }
    }

    // Modulo reduction.
    for (uint16_t bitCnt = 15; bitCnt > 7; bitCnt--)
    {
        if ((temp >> bitCnt) & 0x01)
        {
            temp ^= (m_IrreduciblePolynomial << (bitCnt - 8));
        }
    }

    // Result should now fit within a byte after the modulo reduction.
    return static_cast<uint8_t>(temp);
}
I can't come up with any other tricks. The compiler (for an embedded platform) must be doing a pretty decent job, because anything I try just makes it worse. Any insights would be appreciated. Thanks!
Use a lookup table that maps every number in the range 0-255 to a list of integers in the range 0-7. The numbers in this list are the positions of the 1 bits. Take the number 147, which is 10010011 in binary: its lookup-table entry would be the array {0, 1, 4, 7}, because bit positions 0, 1, 4, and 7 (counting from the least significant bit) hold 1 bits. (For the record, I'm amazed that the positions of the 1 bits spell out the number in decimal. What a strange coincidence!)

Then the loop becomes the following.

int* indices = mapping[inputA];
for (int i = 0; i < listlength; ++i)
{
    temp ^= (inputB << indices[i]);
}


Of course, we need to know how many items are in the list, since that differs from number to number. So make a *second* array that maps each integer in the range 0-255 to a single integer: its number of 1 bits. That way, all the inner arrays of the multi-dimensional array can have exactly 8 slots, so it's rectangular. The code then becomes this:

int length = lengths[inputA];
int* indices = mapping[inputA];
for (int i = 0; i < length; ++i)
{
    temp ^= (inputB << indices[i]);
}



This removes one of the conditionals, but another one remains (the loop termination test), and it is executed for every item in the list. We'd like to remove that too. For this you can use a switch statement with fallthrough. In the general case you'd use Duff's Device, but since the list can hold at most 8 items, it's easier.

int length = lengths[inputA];
int* indices = mapping[inputA];
switch (length)
{
case 8: temp ^= (inputB << *indices++);
case 7: temp ^= (inputB << *indices++);
case 6: temp ^= (inputB << *indices++);
case 5: temp ^= (inputB << *indices++);
case 4: temp ^= (inputB << *indices++);
case 3: temp ^= (inputB << *indices++);
case 2: temp ^= (inputB << *indices++);
case 1: temp ^= (inputB << *indices++);
}


This is basically just an unrolled loop. Now you've gotten the entire loop, which previously had 16 conditionals, down to a single conditional just to determine where in the switch statement to jump to.
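To put the table-plus-switch idea in one place, here is a minimal sketch. It assumes the two tables are built at startup instead of being pasted in as literals. `lengths` and `mapping` follow the post; `BuildBitTables` and `PolyMulUnrolled` are hypothetical names. The function returns the unreduced 16-bit polynomial product, so it replaces only the first loop of `Aes::Mult`.

```cpp
#include <cassert>
#include <cstdint>

// Tables from the post (names kept): lengths[a] is the number of 1 bits in a,
// mapping[a][k] is the position of the k-th 1 bit, lowest bit first.
// Every row has 8 slots, so the array is rectangular.
static uint8_t lengths[256];
static uint8_t mapping[256][8];

// Hypothetical startup routine: count the 1 bits and record their positions.
void BuildBitTables()
{
    for (int a = 0; a < 256; ++a)
    {
        int count = 0;
        for (int bit = 0; bit < 8; ++bit)
        {
            if ((a >> bit) & 1)
            {
                mapping[a][count++] = static_cast<uint8_t>(bit);
            }
        }
        lengths[a] = static_cast<uint8_t>(count);
    }
}

// Hypothetical name. Carry-less (polynomial) product of inputA and inputB,
// NOT yet reduced modulo the AES polynomial; the result fits in 15 bits.
uint16_t PolyMulUnrolled(uint8_t inputA, uint8_t inputB)
{
    const unsigned b = inputB;                // widen before shifting
    const uint8_t* indices = mapping[inputA];
    unsigned temp = 0;

    // Entering at case N performs exactly N XORs via intentional fallthrough.
    switch (lengths[inputA])
    {
    case 8: temp ^= b << *indices++;  // fall through
    case 7: temp ^= b << *indices++;  // fall through
    case 6: temp ^= b << *indices++;  // fall through
    case 5: temp ^= b << *indices++;  // fall through
    case 4: temp ^= b << *indices++;  // fall through
    case 3: temp ^= b << *indices++;  // fall through
    case 2: temp ^= b << *indices++;  // fall through
    case 1: temp ^= b << *indices++;
    }
    // Length 0 (inputA == 0) matches no case, so temp stays 0.
    return static_cast<uint16_t>(temp);
}
```

For the FIPS-197 example {57} • {83}, `PolyMulUnrolled(0x57, 0x83)` should give 0x2b79 before reduction. The reduction step then turns this into 0xc1.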

Do the same thing for the other loop.
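One caveat about reusing the bit-position table for the reduction loop: each XOR there can flip high-byte bits that have not been visited yet. For example, reducing bit 15 XORs in 0x11b << 7 = 0x8d80, which touches bits 15, 11, 10, 8 and 7. As a result, the positions can't be looked up from the original high byte. A different table trick, which does not come from this thread, does work. Reduction modulo m(x) is linear over GF(2), so the result is the low byte XORed with a value that depends only on the high byte. The sketch below assumes a 256-byte table; `reduceTable`, `ReduceSlow`, `BuildReduceTable` and `ReduceFast` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical table: reduceTable[h] = (h(x) * x^8) mod m(x),
// with m(x) = x^8 + x^4 + x^3 + x + 1 (0x11b), as in the original Mult().
static uint8_t reduceTable[256];

// Bit-by-bit reduction: the same algorithm as the original second loop.
uint8_t ReduceSlow(uint16_t temp)
{
    for (int bit = 15; bit > 7; --bit)
    {
        if ((temp >> bit) & 1)
        {
            temp ^= static_cast<uint16_t>(0x011b << (bit - 8));
        }
    }
    return static_cast<uint8_t>(temp);
}

// Fill the table once at startup, reusing the slow reduction.
void BuildReduceTable()
{
    for (int h = 0; h < 256; ++h)
    {
        reduceTable[h] = ReduceSlow(static_cast<uint16_t>(h << 8));
    }
}

// Because reduction is linear over GF(2), the high byte's contribution can be
// looked up and XORed into the low byte: one load and one XOR, no branches.
uint8_t ReduceFast(uint16_t temp)
{
    return static_cast<uint8_t>((temp & 0xff) ^ reduceTable[temp >> 8]);
}
```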
unsigned char FFMul(unsigned char a, unsigned char b)
{
    unsigned char aa = a, bb = b, r = 0, t;
    while (aa != 0)
    {
        if ((aa & 1) != 0)
            r = r ^ bb;
        t = bb & 0x80;
        bb = bb << 1;
        if (t != 0)
            bb = bb ^ 0x1b;
        aa = aa >> 1;
    }
    return r;
}


Source.

On a PC, this version takes half the time. Beyond that, no comparison is possible: branch and memory performance simply differ too much between PCs and embedded platforms.


Also, depending on memory: if speed really is at a premium, consider a precalculated table. It's 64 kilobytes, but then multiplication becomes a single memory access: 'mult = tab[a][b];'
Depending on your target platform/compiler, doing arithmetic on 8/16 bit integers will cause lots of useless masking and sign extension. If the platform has 32 bit registers, use 32 bit ints for all internal operations even if member variables/parameters want to be 8/16 bit.
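Here is a minimal sketch of that suggestion applied to FFMul. It assumes a target where 32-bit registers are native. Whether it actually helps depends on the compiler and instruction set, so the generated assembly is worth checking. `FFMul32` is a hypothetical name. Instead of testing bit 7 before the shift, it tests bit 8 after the shift, which is equivalent.

```cpp
#include <cassert>
#include <cstdint>

// FFMul with 32-bit internal state; only the interface stays 8-bit.
uint8_t FFMul32(uint8_t a, uint8_t b)
{
    uint32_t aa = a;
    uint32_t bb = b;
    uint32_t r = 0;
    while (aa != 0)
    {
        if (aa & 1)
        {
            r ^= bb;          // add (XOR) the current multiple of b
        }
        bb <<= 1;             // multiply bb by x
        if (bb & 0x100)       // a degree-8 term appeared: reduce by m(x) = 0x11b
        {
            bb ^= 0x11b;      // clears bit 8 and folds in 0x1b
        }
        aa >>= 1;
    }
    return static_cast<uint8_t>(r);
}
```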
Can't you just precompute a 64-Kbyte table with all the results? I doubt you can do faster than that.
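For completeness, here is a sketch of the 64 KB table approach. It assumes the table is filled once at startup from a slower reference multiply. `GfMulSlow`, `mulTable`, `BuildMulTable` and `GfMulTable` are hypothetical names. As the original poster explains later in the thread, 64 KB is too much for his target, so this only makes sense where memory and cache allow.

```cpp
#include <cassert>
#include <cstdint>

// Reference multiply: shift-and-add with the AES polynomial (same idea as FFMul).
static uint8_t GfMulSlow(uint8_t a, uint8_t b)
{
    unsigned aa = a, bb = b, r = 0;
    while (aa != 0)
    {
        if (aa & 1) r ^= bb;
        bb <<= 1;
        if (bb & 0x100) bb ^= 0x11b;
        aa >>= 1;
    }
    return static_cast<uint8_t>(r);
}

// Hypothetical full product table: 256 * 256 = 65,536 bytes (64 KB).
static uint8_t mulTable[256][256];

void BuildMulTable()
{
    for (int a = 0; a < 256; ++a)
        for (int b = 0; b < 256; ++b)
            mulTable[a][b] = GfMulSlow(static_cast<uint8_t>(a), static_cast<uint8_t>(b));
}

// After BuildMulTable(), each multiplication is a single memory access.
inline uint8_t GfMulTable(uint8_t a, uint8_t b)
{
    return mulTable[a][b];
}
```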
Quote: Original post by alvaro
Can't you just precompute a 64-Kbyte table with all the results? I doubt you can do faster than that.


Since it's an embedded environment, this probably wouldn't be possible. I'd be surprised if the method I gave above weren't significantly faster, though. Precomputing the whole table is obviously the fastest option, if it's possible.
Precomputing a 64kB table would be prohibitive. The system only has 1MB internal memory which is divided into cache (256kB), BIOS/drivers, and general use (about 650kB). Putting the table into internal memory would eat up 10% of available resources, and putting it in external memory could trash what little (and very precious) cache I have.

@ Antheus: very nice link. I'll have to see what the compiler does with that. Your link also has a second solution, FFMulFast(), which uses two 256-byte lookup tables and a couple of conditional checks.
uint8_t FFMulFast(uint8_t a, uint8_t b)
{
    uint16_t t = 0;
    if (a == 0 || b == 0) return 0;
    t = L[a] + L[b];
    if (t > 255) t = t - 255;
    return E[t];
}

Here L and E are the two 256-byte LUTs (logarithm and exponential tables). That should prove to be even better. I can easily support tables that size.
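Instead of pasting L and E in as literals, they can be generated at startup. The following minimal sketch assumes generator 0x03, which yields the same values as the Exp/Log arrays in the benchmark code later in this thread. `XTime` and `BuildLogTables` are hypothetical names. Multiplying by 0x03 is computed as p ^ xtime(p), since 0x03 = x + 1.

```cpp
#include <cassert>
#include <cstdint>

// Tables named as in the post: E[i] = 0x03^i (exponential), L[v] = log base 0x03 of v.
static uint8_t E[256];
static uint8_t L[256];

// Multiply by x (0x02) modulo the AES polynomial 0x11b.
static uint8_t XTime(uint8_t v)
{
    return static_cast<uint8_t>((v << 1) ^ ((v & 0x80) ? 0x1b : 0x00));
}

// Hypothetical startup routine that fills both tables.
void BuildLogTables()
{
    uint8_t p = 1;
    for (int i = 0; i < 255; ++i)
    {
        E[i] = p;
        L[p] = static_cast<uint8_t>(i);
        p = static_cast<uint8_t>(p ^ XTime(p));  // p *= 0x03, because 0x03 = x + 1
    }
    E[255] = 1;  // 0x03^255 == 1, so the t == 255 case below is covered
    L[0] = 0;    // log(0) is undefined; FFMulFast returns early for zero
}

uint8_t FFMulFast(uint8_t a, uint8_t b)
{
    if (a == 0 || b == 0) return 0;
    unsigned t = L[a] + L[b];  // 0..508
    if (t > 255) t -= 255;     // exponents are taken modulo 255
    return E[t];
}
```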

I'll try the various methods tomorrow and let you guys know what happens. Thanks everyone!
I implemented three versions of this multiplication under MSVC2008.
Version 1 - my original implementation
Version 2 - Antheus version
Version 3 - second method provided by the link Antheus posted (uses two 256-byte LUTs)

Here are the results:

DEBUG
-----
Correctness Test: Same = 65536, Different = 0

Performance Test (10000000 iterations):
Method 1: Output (ignore) = 10, Duration = 2.56581 seconds.
Method 2: Output (ignore) = 130, Duration = 1.91427 seconds.
Method 3: Output (ignore) = 84, Duration = 0.785167 seconds.


RELEASE
-------
Correctness Test: Same = 65536, Different = 0

Performance Test (10000000 iterations):
Method 1: Output (ignore) = 172, Duration = 1.44255 seconds.
Method 2: Output (ignore) = 25, Duration = 1.04715 seconds.
Method 3: Output (ignore) = 21, Duration = 0.388543 seconds.


The first pass, the Correctness Test, makes sure that for all inputs A (0-255) and B (0-255), the outputs of all three methods are the same. The performance tests iterate through each method 10,000,000 times. Each method has a unique set of input and output vectors (10 million A and B inputs for each method). To keep the optimizer from cheating, I output a random selection from each output vector (the "Output (ignore)" in the print status).

Built and ran the tests under both Debug and Release builds (standard Debug and Release settings, nothing changed from the default MSVC2008 settings for those).

I'll do these same tests under my embedded environment and post what I get there. Here's the sample code I used under MSVC2008.
#include "stdafx.h"
#include "windows.h"
#include <cstdlib>
#include <ctime>
#include <iostream>

unsigned short m_IrreduciblePolynomial = 0x011b;

unsigned char Exp[256] = {
    0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
    0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
    0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
    0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
    0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
    0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
    0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
    0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
    0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
    0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
    0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
    0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
    0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
    0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
    0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
    0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};

unsigned char Log[256] = {
    0x00, 0x00, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03,
    0x64, 0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1,
    0x7d, 0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78,
    0x65, 0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e,
    0x96, 0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38,
    0x66, 0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10,
    0x7e, 0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba,
    0x2b, 0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57,
    0xaf, 0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8,
    0x2c, 0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0,
    0x7f, 0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7,
    0xcc, 0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d,
    0x97, 0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1,
    0x53, 0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab,
    0x44, 0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5,
    0x67, 0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};

unsigned char Mult1(unsigned char inputA, unsigned char inputB)
{
    //////////////
    // Method 1 //
    //////////////
    unsigned short temp = 0;

    // Polynomial multiplication.
    for (unsigned short bitCnt = 0; bitCnt < 8; bitCnt++)
    {
        if ((inputA >> bitCnt) & 0x01)
        {
            temp ^= (inputB << bitCnt);
        }
    }

    // Modulo reduction.
    for (unsigned short bitCnt = 15; bitCnt > 7; bitCnt--)
    {
        if ((temp >> bitCnt) & 0x01)
        {
            temp ^= (m_IrreduciblePolynomial << (bitCnt - 8));
        }
    }

    // Result should now fit within a byte after the modulo reduction.
    return static_cast<unsigned char>(temp);
}

unsigned char Mult2(unsigned char inputA, unsigned char inputB)
{
    //////////////
    // Method 2 //
    //////////////
    unsigned char aa = inputA, bb = inputB, r = 0, t;

    while (aa != 0)
    {
        if ((aa & 1) != 0)
            r = r ^ bb;
        t = bb & 0x80;
        bb = bb << 1;
        if (t != 0)
            bb = bb ^ 0x1b;

        aa = aa >> 1;
    }
    return r;
}

unsigned char Mult3(unsigned char inputA, unsigned char inputB)
{
    //////////////
    // Method 3 //
    //////////////
    unsigned short t = 0;
    if (inputA == 0 || inputB == 0)
        return 0;
    t = Log[inputA] + Log[inputB];
    if (t > 255)
        t = t - 255;
    return Exp[t];
}

int _tmain(int argc, _TCHAR* argv[])
{
    LARGE_INTEGER proc_freq;
    QueryPerformanceFrequency(&proc_freq);
    double frequency = (1.0 / static_cast<double>(proc_freq.QuadPart));

    LARGE_INTEGER start;
    LARGE_INTEGER stop;
    double diff;

    unsigned char output1;
    unsigned char output2;
    unsigned char output3;
    int same = 0;
    int different = 0;
    int iterations = 10000000;

    unsigned char* inputArrayA1 = new unsigned char[iterations];
    unsigned char* inputArrayA2 = new unsigned char[iterations];
    unsigned char* inputArrayA3 = new unsigned char[iterations];
    unsigned char* inputArrayB1 = new unsigned char[iterations];
    unsigned char* inputArrayB2 = new unsigned char[iterations];
    unsigned char* inputArrayB3 = new unsigned char[iterations];
    unsigned char* outputArray1 = new unsigned char[iterations];
    unsigned char* outputArray2 = new unsigned char[iterations];
    unsigned char* outputArray3 = new unsigned char[iterations];

    srand(static_cast<unsigned int>(time(NULL)));
    for (int i = 0; i < iterations; i++)
    {
        inputArrayA1[i] = static_cast<unsigned char>(rand());
        inputArrayA2[i] = static_cast<unsigned char>(rand());
        inputArrayA3[i] = static_cast<unsigned char>(rand());
        inputArrayB1[i] = static_cast<unsigned char>(rand());
        inputArrayB2[i] = static_cast<unsigned char>(rand());
        inputArrayB3[i] = static_cast<unsigned char>(rand());
    }

    // Correctness test.
    for (int i = 0; i < 256; i++)
    {
        for (int j = 0; j < 256; j++)
        {
            output1 = Mult1(static_cast<unsigned char>(i), static_cast<unsigned char>(j));
            output2 = Mult2(static_cast<unsigned char>(i), static_cast<unsigned char>(j));
            output3 = Mult3(static_cast<unsigned char>(i), static_cast<unsigned char>(j));
            if ((output1 == output2) && (output2 == output3))
            {
                same++;
            }
            else
            {
                different++;
            }
        }
    }
    std::cout << "Correctness Test: Same = " << same << ", Different = " << different << std::endl << std::endl;

    // Performance test.
    std::cout << "Performance Test (" << iterations << " iterations):" << std::endl;

    QueryPerformanceCounter(&start);
    for (int i = 0; i < iterations; i++)
    {
        outputArray1[i] = Mult1(inputArrayA1[i], inputArrayB1[i]);
    }
    QueryPerformanceCounter(&stop);
    diff = (stop.QuadPart - start.QuadPart) * frequency;
    std::cout << "Method 1: Output (ignore) = " << static_cast<unsigned short>(outputArray1[rand() % iterations]) << ", Duration = " << diff << " seconds." << std::endl;

    QueryPerformanceCounter(&start);
    for (int i = 0; i < iterations; i++)
    {
        outputArray2[i] = Mult2(inputArrayA2[i], inputArrayB2[i]);
    }
    QueryPerformanceCounter(&stop);
    diff = (stop.QuadPart - start.QuadPart) * frequency;
    std::cout << "Method 2: Output (ignore) = " << static_cast<unsigned short>(outputArray2[rand() % iterations]) << ", Duration = " << diff << " seconds." << std::endl;

    QueryPerformanceCounter(&start);
    for (int i = 0; i < iterations; i++)
    {
        outputArray3[i] = Mult3(inputArrayA3[i], inputArrayB3[i]);
    }
    QueryPerformanceCounter(&stop);
    diff = (stop.QuadPart - start.QuadPart) * frequency;
    std::cout << "Method 3: Output (ignore) = " << static_cast<unsigned short>(outputArray3[rand() % iterations]) << ", Duration = " << diff << " seconds." << std::endl;

    delete[] outputArray1;
    delete[] outputArray2;
    delete[] outputArray3;
    delete[] inputArrayA1;
    delete[] inputArrayA2;
    delete[] inputArrayA3;
    delete[] inputArrayB1;
    delete[] inputArrayB2;
    delete[] inputArrayB3;
    return 0;
}

Some interesting results on the embedded side of things. Running all three methods through my embedded testbench, I get the following:

Method 1: Mult() cycles = 8,134,008, total system run-time = 23.157ms
Method 2: Mult() cycles = 3,912,233, total system run-time = 16.124ms
Method 3: Mult() cycles = 5,823,237, total system run-time = 19.307ms

As you can see, this multiplication was taking a large portion of my total system run-time. Of the 23.157ms, Mult() was taking 13.550ms with Method 1, which is well over half the total run-time. Method 2 reduces this to 6.517ms and Method 3 to 9.700ms.
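The derived Mult() times follow if everything other than Mult() costs the same in all three runs. The cycle counts also imply a core clock of roughly 600 MHz. That figure is an inference from these numbers, not something stated in the thread.

```latex
\begin{aligned}
t_{\text{rest}}   &= 23.157 - 13.550 = 9.607\ \text{ms} \\
t_{\text{Mult},2} &= 16.124 - 9.607 = 6.517\ \text{ms} \\
t_{\text{Mult},3} &= 19.307 - 9.607 = 9.700\ \text{ms} \\
f_{\text{clk}}    &\approx \frac{8{,}134{,}008\ \text{cycles}}{13.550\ \text{ms}}
                   \approx \frac{3{,}912{,}233\ \text{cycles}}{6.517\ \text{ms}}
                   \approx 600\ \text{MHz}
\end{aligned}
```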

On the embedded system, Method 3 is actually slower than Method 2. Looking at the generated assembly, the compiler created the same number of instructions, but was able to pipeline the instructions more for Method 2 than Method 3. A little surprising, but nice, since Method 2 uses less memory than Method 3.


Thanks for your input, everyone!
All of the new methods still use a lot of nested conditionals. I'd be interested to see how the method I posted performs, although it may no longer be a bottleneck with your new implementation and not worth testing further. The 2 lookup tables I mentioned are pretty easy to pre-compute with a simple program that counts 1-bits and spits them out. Then just copy/paste it into your source.

Also, compilers usually optimize if-then statements for the "true" case. That is, you will have more successful branch predictions if the conditions evaluate to true more often than they evaluate to false. You could profile how often they evaluate to true vs false and make sure they're written correctly. If you do a profile-guided optimization in Visual Studio, I think it will do this for you.
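On the remaining conditionals: a common alternative, not from this thread, replaces the data-dependent branches with masks, since 0 - (x & 1) is either all ones or zero. The following is a minimal sketch; `GfMulBranchless` is a hypothetical name. It always runs exactly 8 iterations, so it has no data-dependent branches, and the compiler may fully unroll it. Whether it beats Method 2 on the embedded target would have to be measured. Code of this form is also commonly used in AES implementations to avoid timing differences that depend on the data.

```cpp
#include <cassert>
#include <cstdint>

// Branch-free GF(2^8) multiply with the AES polynomial 0x11b (a sketch).
uint8_t GfMulBranchless(uint8_t a, uint8_t b)
{
    uint32_t aa = a;
    uint32_t bb = b;
    uint32_t r = 0;
    for (int i = 0; i < 8; ++i)
    {
        // mask is 0xFFFFFFFF if the low bit of aa is set, otherwise 0.
        uint32_t mask = 0u - (aa & 1u);
        r ^= bb & mask;
        // Multiply bb by x; if bit 7 was set, fold in 0x1b (x^8 = x^4 + x^3 + x + 1).
        uint32_t carry = 0u - ((bb >> 7) & 1u);
        bb = ((bb << 1) ^ (0x1bu & carry)) & 0xffu;
        aa >>= 1;
    }
    return static_cast<uint8_t>(r);
}
```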

