My question is a follow-up to "How to make this code faster (learning best practices)?", which has been put on hold (bummer). The problem is to optimize a loop over an array of floats, testing each element for whether it lies within a given interval. The indices of the matching elements are to be stored in a provided result array.
The test consists of two conditions: the element must be at most the upper threshold and at least the lower one. The obvious code for the test is if( elem <= upper && elem >= lower ) ... . I observed that branching (including the implicit branch introduced by the short-circuiting operator&&) is much more expensive than the second comparison. What I came up with is below. It is about 20%-40% faster than a naive implementation, which is more than I expected. It uses the fact that bool is an integral type whose values convert to 0 and 1. The result of the condition test is used as an index into two result arrays. Only one of them will contain the desired data; the other one can be discarded. This replaces program structure with data structure and computation.
I am interested in more ideas for optimization. "Technical hacks" (of the kind provided here) are welcome. I'm also interested in whether modern C++ offers ways to make this faster, e.g. by enabling the compiler to generate code that runs in parallel. Think visitor pattern/functor. The computations on individual srcArr elements are almost independent, except that the order of the indices in the result array depends on the order in which the source array elements are tested. I would relax the requirements a little so that the order of the matching indices reported in the result array is irrelevant. Can anybody come up with a fast way?
Here is the source code of the function, followed by a supporting main. gcc needs -std=c++11 because of <chrono>. VS 2013 Express was able to compile this too (and produced code that was 40% faster than gcc -O3).
```cpp
#include <cstdlib>
#include <iostream>
#include <chrono>

using namespace std;
using namespace std::chrono;

/// Check all elements in srcArr whether they lie in
/// the interval [lower, upper]. Store the indices of
/// such elements in the array pointed to by destArr[1]
/// and return the number of matching elements found.
/// This has been highly optimized, mainly to avoid branches.
/// Assumes lower <= upper and that srcArr contains no NaNs
/// (see the comment inside the loop).
int findElemsInInterval( const float srcArr[],  // contains candidates
                         int **const destArr,   // two arrays to be filled with indices
                         const int arrLen,      // length of each array
                         const float lower, const float upper // interval
                       )
{
    // Instead of branching, use the condition
    // as an index into two distinct arrays. We need to keep
    // separate indices for both of those arrays.
    int destIndices[2];
    destIndices[0] = destIndices[1] = 0;
    for( int srcInd=0; srcInd<arrLen; ++srcInd )
    {
        // If the element is inside the interval, both conditions
        // are true and therefore equal. In all other cases
        // exactly one condition is true, so they are not equal.
        // (This holds only if lower <= upper and the element is not NaN;
        // a NaN makes both comparisons false, which also compares equal.)
        // Matching elements' indices are therefore stored in destArr[1].
        // destArr[0] is a kind of dummy (it will incidentally contain
        // the indices of non-matching elements).
        // This used to be (with a simple int *destArr)
        // if( srcArr[srcInd] <= upper && srcArr[srcInd] >= lower ) destArr[destIndex++] = srcInd;
        int isInInterval = (srcArr[srcInd] <= upper) == (srcArr[srcInd] >= lower);
        destArr[isInInterval][destIndices[isInInterval]++] = srcInd;
    }
    return destIndices[1]; // the number of elements in the result array
}

int main(int argc, char *argv[])
{
    int arrLen = 1000*1000*100;
    if( argc > 1 ) arrLen = atoi(argv[1]);
    // destArr[1] will hold the indices of elements which
    // are within the interval.
    int *destArr[2];
    // We don't check destination boundaries, so make them
    // the same length as the source.
    destArr[0] = new int[arrLen];
    destArr[1] = new int[arrLen];
    float *srcArr = new float[arrLen];
    // Always create the same numbers for comparison (don't srand).
    for( int srcInd=0; srcInd<arrLen; ++srcInd ) srcArr[srcInd] = rand();
    // Create an interval in the middle of the rand() spectrum.
    float lowerLimit = RAND_MAX/3;
    float upperLimit = lowerLimit*2;
    cout << "lower = " << lowerLimit << ", upper = " << upperLimit << endl;
    int numInterval = 0;
    auto t1 = high_resolution_clock::now(); // wall-clock time as an approximation
    // Call the function a few times to get a longer run time.
    for( int rep=0; rep<10; ++rep )
        numInterval = findElemsInInterval( srcArr, destArr, arrLen, lowerLimit, upperLimit );
    auto t2 = high_resolution_clock::now();
    auto elapsedMs = duration_cast<milliseconds>( t2 - t1 ).count();
    cout << numInterval << " elements found in " << elapsedMs << " milliseconds." << endl;
    delete[] srcArr;
    delete[] destArr[1];
    delete[] destArr[0];
    return 0;
}
```