my SIMD implementation is very slow :(

Started by
18 comments, last by Aressera 9 years, 7 months ago

Instead of using double-indirection [...]


It is actually a quadruple (4x) indirection: 1. idToIntersect[] 2. m_triBuffer[] 3. LocalTri->indiceX 4. m_vertices[]
This is incredibly bad because it "amplifies" your memory problems. CPUs always try to execute instructions in parallel, or at least overlapping, if they don't depend on each other. For float addition, you need at least 3 independent (vector) additions (with AVX that is 3x8 floats) in flight to fully saturate an Ivy Bridge ALU. For multiplication it's 5 independent ones. The problem with pointer chasing like the above indirection is that everything depends on the result of a long chain of operations, in which every operation depends on the previous one. Computations cannot start until step (4), loading m_vertices[], has completed. That, however, can only start when the index is known, which means that loading LocalTri->indiceX must fully complete. That again can only start after loading from m_triBuffer[] has fully completed, and so on. Until this chain of operations is done, most of your CPU is idle because there is nothing to execute in parallel.
If you are lucky, everything is in the L1 cache and every load can be serviced in 4 cycles. Then it takes a total of 16 cycles before actual computations start, and remember that with AVX 16 cycles are worth 128 floating point operations. Now let's assume that your triangle sizes increase and all the stuff no longer fits into the L1 cache, but has to be loaded from the L2 cache. The L2 latency is 10 cycles, so just a 6 cycle increase; that's not a big deal. But since you have 4x indirection, you actually get 4x that 6 cycle increase, assuming that everything is in L2 of course. If your triangle sizes increase even more, you might have to load stuff from the L3. I don't know the latencies for the L3, but let's assume they are just 30 cycles. If all those loads hit the L3, then it takes you 4 x 30 cycles = 120 cycles before you even start computing. That is almost 1000 floating point operations wasted.
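
To make the arithmetic in this post easy to replay, here is a small sketch. The latencies are the rough figures quoted above (4, 10 and an assumed 30 cycles), and "one 8-wide AVX instruction per cycle" is the same simplification the post uses; the function names are made up for illustration.

```cpp
// Cycles spent waiting when `chainDepth` dependent loads each take
// `loadLatency` cycles (no overlap is possible inside the chain).
constexpr int stallCycles(int chainDepth, int loadLatency) {
    return chainDepth * loadLatency;
}

// Float operations that could have been issued in the meantime, assuming
// one vector instruction per cycle with `lanes` floats each (8 for AVX).
constexpr int lostFloatOps(int cycles, int lanes) {
    return cycles * lanes;
}

// The three cases from the post, for a 4x indirection:
constexpr int kAllInL1 = stallCycles(4, 4);   // 16 cycles
constexpr int kAllInL2 = stallCycles(4, 10);  // 40 cycles
constexpr int kAllInL3 = stallCycles(4, 30);  // 120 cycles (30 is a guess)
```

With 8 lanes these correspond to 128, 320 and 960 float operations respectively, which is where the "almost 1000" figure comes from.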

Of course some of those loads will probably always hit the L1, but pointer chasing / indirection is extremely bad, and it can severely amplify the effects of cache misses. Getting rid of that would be even higher on my priority list than reducing the size of the data structures.
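
A minimal sketch of what getting rid of the indirection could look like. Only the names idToIntersect, m_triBuffer, indiceX and m_vertices come from this thread; the struct layouts and the flatten() helper are assumptions made for illustration, not the original code.

```cpp
#include <vector>

struct Vec3 { float x, y, z; };

// Indexed layout, loosely modelled on the chain described above.
struct LocalTri { unsigned indice0, indice1, indice2; };

struct IndexedScene {
    std::vector<Vec3>     vertices;   // stands in for m_vertices
    std::vector<LocalTri> triBuffer;  // stands in for m_triBuffer
};

// Flattened layout: the three corners are stored inline.
struct FlatTri { Vec3 v0, v1, v2; };

// Resolve every index once, at build time, so that the hot loop does one
// sequential load per triangle instead of four dependent ones.
std::vector<FlatTri> flatten(const IndexedScene& scene,
                             const std::vector<unsigned>& idToIntersect) {
    std::vector<FlatTri> out;
    out.reserve(idToIntersect.size());
    for (unsigned id : idToIntersect) {
        const LocalTri& t = scene.triBuffer[id];
        out.push_back({ scene.vertices[t.indice0],
                        scene.vertices[t.indice1],
                        scene.vertices[t.indice2] });
    }
    return out;
}
```

The trade-off is memory: shared vertices are duplicated, so the flattened array is larger than the indexed one. Whether that is a net win depends on the scene and has to be measured.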


As a side note, wouldn't it be easier to not simd-vectorize the vector operations, but instead to perform the computations for 8 or 16 rays simultaneously?

Okay so I got some comparisons with my non-SIMD to SIMD code:

Implicit sphere based scene (so little memory access)

NO-SIMD: 15s SIMD: 11.5s (yay slight improvement!)

Large tribased scene (a tank, so lots of memory access)

NO-SIMD: 6.6s SIMD: 600 seconds !!!!!!!!!!

both of them show similar stall areas, but very simply my SIMD version (everything is the same size) makes everything implode on tri-based scenes and takes significantly longer in stall-prone areas. Which is something I really don't get! Non-SIMD uses 4 floats x,y,z,w and the SIMD version uses __m128. Both produce the same images, so I know the code is accurate, but just very very slow.

@Ohforf sake

Uh oh :( I'll have a look at removing the indirection :)

With your suggestion of basically batching rays, would the compiler then auto-optimise? What would a standard operation look like?

So like:

// 4 ray batch

ray1 = ray1 + anotherRay;

ray2 = ray2 + anotherRay;

ray3 = ray3 + anotherRay;

ray4 = ray4 + anotherRay;

and the addition operation is just (rather than fancy SIMD):

ray1.x = ray1.x + anotherRay.x

ray1.y = ray1.y + anotherRay.y

etc

Then I would assume, that if the rays do something completely different to each other (like reflect, and the other continues onwards or something) I would batch them up separately? (This approach would be better for GPUs too I guess?)

6 vs 600, that's somehow incredible if everything works as it should.
I'd rather expect a bug (e.g. your traversing checks don't early-out but traverse the whole tree every time due to some SIMD bug, maybe)

btw. have you enabled code generation for SSE?

go to Project -> Properties -> C/C++ -> Code Generation -> Enable Enhanced Instruction Set and set it to SSE2 or AVX (and Floating Point Model to "Fast", of course)

I hope you've also enabled all the optimizations ;)
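
One way to double-check those settings from inside the program is to look at the macros the compiler predefines for the chosen instruction set. This is only a sketch: the function name is made up, and it reports what the compiler was allowed to target, not what the CPU supports at run time.

```cpp
#include <string>

// Reports which vector instruction set this translation unit was compiled
// for, based on macros predefined by GCC/Clang (__SSE2__, __AVX__, __AVX2__)
// and MSVC (__AVX__, __AVX2__, _M_X64, _M_IX86_FP).
inline std::string enabledInstructionSet() {
#if defined(__AVX2__)
    return "AVX2";
#elif defined(__AVX__)
    return "AVX";
#elif defined(__SSE2__) || defined(_M_X64) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
    return "SSE2";
#else
    return "scalar/unknown";
#endif
}
```

Printing the result once at start-up makes it obvious when a Debug or default configuration is being profiled by accident.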

NO-SIMD: 6.6s SIMD: 600 seconds !!!!!!!!!!

Wow, that is not normal. Are there any other hot spots except for the one in the screenshot?

With your suggestion of basically batch rays, would then the compiler auto optimise, what would a standard operation look like?
So like:
...snip...

Yes that is what I meant, although I would try to hide it as far as possible behind classes with overloaded operators. So instead of a Ray class, you could have a FourRays class, or even better a templated MultiRays<4> class.

In my experience, the auto vectorization of the compiler only kicks in if you basically pre arranged everything. Like lots of
for (unsigned i = 0; i < 4; i++)
    result[i] = input1[i] + input2[i];
at which point, actually writing the intrinsics yourself makes the code a lot cleaner IMO. But if you put that into the implementations of overloaded operators, you can change it and experiment with it fairly easily.
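
A minimal sketch of that idea, using SSE (4 lanes) to keep it short. Float4, FourRays and pointsAt are illustrative names, not an existing library; a MultiRays<N> template would follow the same pattern with a wider register type.

```cpp
#include <xmmintrin.h>  // SSE intrinsics

// Thin wrapper so that vector code reads like scalar code.
struct Float4 {
    __m128 v;
    Float4() : v(_mm_setzero_ps()) {}
    explicit Float4(float s) : v(_mm_set1_ps(s)) {}
    Float4(float a, float b, float c, float d) : v(_mm_setr_ps(a, b, c, d)) {}
    explicit Float4(__m128 m) : v(m) {}

    // Read back one lane (slow; meant for debugging and tests only).
    float lane(int i) const {
        alignas(16) float t[4];
        _mm_store_ps(t, v);
        return t[i];
    }
};

inline Float4 operator+(Float4 a, Float4 b) { return Float4(_mm_add_ps(a.v, b.v)); }
inline Float4 operator*(Float4 a, Float4 b) { return Float4(_mm_mul_ps(a.v, b.v)); }

// Four rays in structure-of-arrays form: lane i of every member is ray i.
struct FourRays {
    Float4 ox, oy, oz;  // origins
    Float4 dx, dy, dz;  // directions
};

// Point on each ray at its own distance t: p = origin + direction * t.
inline void pointsAt(const FourRays& r, Float4 t,
                     Float4& px, Float4& py, Float4& pz) {
    px = r.ox + r.dx * t;
    py = r.oy + r.dy * t;
    pz = r.oz + r.dz * t;
}
```

Because the intrinsics are confined to the operators, switching to AVX or back to plain loops only touches the wrapper.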

Then I would assume, that if the rays do something completely different to each other (like reflect, and the other continues onwards or something) I would batch them up separately? (This approach would be better for GPUs too I guess?)

Yes that is a problem, and one that both CPUs and GPUs share. On GPUs even more, because the SIMD width there is 32 (nVidia) and 64 (ATI). I have never written a path tracer, so I can't give any well-founded advice, but maybe even for diverging rays you can extract some benefit from this kind of batching. Every ray still needs to perform ray-triangle intersection tests which are mostly the same, even when the triangles are different. You might end up having to write some clever shuffling mechanism for reading the triangle data into the vector registers, but starting from Haswell, Intel supports gather reads (finally!) which you might be able to put to good use.
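
As a rough illustration of such a shuffling mechanism, the sketch below fills one register with the same component of four different triangles. It assumes the triangle data is stored as one contiguous float array per component; gather4 is a made-up helper name.

```cpp
#include <xmmintrin.h>  // SSE intrinsics

// Software gather: lane i of the result is base[idx[i]].
// With AVX2 enabled, the intrinsic _mm_i32gather_ps( base, indices, 4 )
// performs the same operation in hardware.
inline __m128 gather4(const float* base, const int idx[4]) {
    // _mm_setr_ps places its first argument in lane 0.
    return _mm_setr_ps(base[idx[0]], base[idx[1]], base[idx[2]], base[idx[3]]);
}
```

Four rays that hit four different triangles can then still share one vectorized intersection routine, at the cost of the scattered loads.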

Unfortunately that isn't really the case: at least on the CPU, ray batching/ray packet tracing has little benefit for incoherent rays and actually decreases performance (I don't know about GPUs). It's great for coherent rays though, like primary rays, where you get a nice speedup (not linear, but definitely worth it if you're after performance). For instance, asking the Embree benchmark program for some hard numbers:


           coherent_intersect1 ... 4.550676 Mrps
           coherent_intersect4 ... 6.902795 Mrps
           coherent_intersect8 ... 9.849938 Mrps
         incoherent_intersect1 ... 2.163574 Mrps
         incoherent_intersect4 ... 1.940323 Mrps
         incoherent_intersect8 ... 1.989973 Mrps

Downside, of course, is batching complicates your program's logic, but even if you have a scalar renderer there are still places that can immediately benefit from it with minimum effort, e.g. camera rays for cheaper anti-aliasing, photon mapping first bounce, etc...

The hybrid approach is also possible (start with a batch of rays, and then break it up as soon as the rays go their separate ways) and there are lots of papers online that discuss it :)
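
For the hybrid approach, the renderer needs some test for "the rays have gone their separate ways". One simple criterion, assumed here purely for illustration, is whether all four directions still point into the same octant. The function names are hypothetical.

```cpp
#include <xmmintrin.h>  // SSE intrinsics

// True when all four lanes have the same sign bit.
// Note that -0.0f counts as negative here.
inline bool sameSign(__m128 v) {
    const int m = _mm_movemask_ps(v);  // one sign bit per lane
    return m == 0x0 || m == 0xF;
}

// dx, dy, dz hold the direction components of four rays (one ray per lane).
// The packet is treated as coherent while every axis agrees in sign.
inline bool packetIsCoherent(__m128 dx, __m128 dy, __m128 dz) {
    return sameSign(dx) && sameSign(dy) && sameSign(dz);
}
```

When the check fails, the packet can be split and its rays traced one at a time.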

Oops, my bad - it was a problem with my AABB checks, my SIMD version wasn't working correctly. I feel like such a mug, urgh! (thanks Krypt0n)

It's now roughly the same speed as non-SIMD, this is probably because I'm temporarily extracting the __m128 as a float[4] and performing the scalar comparisons, which I know is very costly for performance (but at least it damn works). I just need to work out a good way of doing SIMD AABB-ray checks :D
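
One possible way to keep the AABB check in registers, sketched for the layout described in this thread (one __m128 per vector, holding x, y, z and an unused fourth lane). The function name and signature are assumptions; it is a plain slab test, not the code from this thread.

```cpp
#include <xmmintrin.h>  // SSE intrinsics

// Slab test for one ray against one box. Nothing is copied out to float[4].
// invDir holds 1/direction per axis.
// Caveat: a ray that is parallel to an axis and starts exactly on a slab
// plane produces 0 * inf = NaN; this sketch does not handle that case.
inline bool rayHitsAABB(__m128 origin, __m128 invDir,
                        __m128 boxMin, __m128 boxMax, float maxDistance) {
    const __m128 t1 = _mm_mul_ps(_mm_sub_ps(boxMin, origin), invDir);
    const __m128 t2 = _mm_mul_ps(_mm_sub_ps(boxMax, origin), invDir);
    const __m128 lo = _mm_min_ps(t1, t2);  // per-axis entry distance
    const __m128 hi = _mm_max_ps(t1, t2);  // per-axis exit distance

    // Broadcast x, y and z so the horizontal max/min stays in registers.
    const __m128 loX = _mm_shuffle_ps(lo, lo, _MM_SHUFFLE(0, 0, 0, 0));
    const __m128 loY = _mm_shuffle_ps(lo, lo, _MM_SHUFFLE(1, 1, 1, 1));
    const __m128 loZ = _mm_shuffle_ps(lo, lo, _MM_SHUFFLE(2, 2, 2, 2));
    const __m128 hiX = _mm_shuffle_ps(hi, hi, _MM_SHUFFLE(0, 0, 0, 0));
    const __m128 hiY = _mm_shuffle_ps(hi, hi, _MM_SHUFFLE(1, 1, 1, 1));
    const __m128 hiZ = _mm_shuffle_ps(hi, hi, _MM_SHUFFLE(2, 2, 2, 2));

    // Latest entry (clamped to 0) and earliest exit (clamped to maxDistance).
    const __m128 tNear = _mm_max_ps(_mm_max_ps(loX, loY),
                                    _mm_max_ps(loZ, _mm_setzero_ps()));
    const __m128 tFar  = _mm_min_ps(_mm_min_ps(hiX, hiY),
                                    _mm_min_ps(hiZ, _mm_set1_ps(maxDistance)));

    // All lanes hold the same value, so comparing lane 0 is enough.
    return _mm_comile_ss(tNear, tFar) != 0;
}
```

This still uses only three of the four lanes; testing one ray against several boxes at once makes better use of the register width.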

either way, thanks for the help guys - the tips on memory and costly indirections were really useful! (I hope I didn't waste your time, I should have checked my AABB earlier :/ )

Check out this code for my SIMD ray tracer. It will show you how to do the traversal and intersection tests fast. (also see Embree stuff, that's the reference I used). The important thing to note is that rather than tracing a ray packet vs 1 AABB, you flatten the tree and intersect 1 ray (4-wide) with 4 child AABBs at once. This is great for incoherent rays, since there is no logic needed to handle splitting up ray packets. My code provides a generic interface for arbitrary primitives, and also can cache certain primitive types (i.e. triangles) for faster access and better storage for SIMD. This code gets me around 10 million incoherent rays/s on an i7 4770k with 8 threads in a scene with 80k triangles (sibenik cathedral model).
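
Before the full listing, here is a stripped-down sketch of the "1 ray vs 4 child AABBs" idea in raw SSE intrinsics. It is not taken from the code below: the struct and function names are made up, and it uses min/max per axis where the listing instead selects the near and far planes through the ray's sign indices.

```cpp
#include <xmmintrin.h>  // SSE intrinsics

// Bounds of four child boxes in structure-of-arrays form:
// lane i of every register belongs to child i.
struct Node4Bounds {
    __m128 minX, maxX, minY, maxY, minZ, maxZ;
};

// One ray, with each component broadcast to all four lanes.
struct BroadcastRay {
    __m128 ox, oy, oz;  // origin
    __m128 ix, iy, iz;  // 1 / direction
};

inline BroadcastRay makeRay(float ox, float oy, float oz,
                            float dx, float dy, float dz) {
    BroadcastRay r;
    r.ox = _mm_set1_ps(ox); r.oy = _mm_set1_ps(oy); r.oz = _mm_set1_ps(oz);
    r.ix = _mm_set1_ps(1.0f / dx);
    r.iy = _mm_set1_ps(1.0f / dy);
    r.iz = _mm_set1_ps(1.0f / dz);
    return r;
}

// Returns a 4-bit mask: bit i is set when the ray hits child i within
// [0, maxT]. tNearOut receives the entry distance for every child.
inline int intersect4(const Node4Bounds& b, const BroadcastRay& r,
                      float maxT, __m128& tNearOut) {
    const __m128 tx1 = _mm_mul_ps(_mm_sub_ps(b.minX, r.ox), r.ix);
    const __m128 tx2 = _mm_mul_ps(_mm_sub_ps(b.maxX, r.ox), r.ix);
    const __m128 ty1 = _mm_mul_ps(_mm_sub_ps(b.minY, r.oy), r.iy);
    const __m128 ty2 = _mm_mul_ps(_mm_sub_ps(b.maxY, r.oy), r.iy);
    const __m128 tz1 = _mm_mul_ps(_mm_sub_ps(b.minZ, r.oz), r.iz);
    const __m128 tz2 = _mm_mul_ps(_mm_sub_ps(b.maxZ, r.oz), r.iz);

    const __m128 tNear = _mm_max_ps(
        _mm_max_ps(_mm_min_ps(tx1, tx2), _mm_min_ps(ty1, ty2)),
        _mm_max_ps(_mm_min_ps(tz1, tz2), _mm_setzero_ps()));
    const __m128 tFar = _mm_min_ps(
        _mm_min_ps(_mm_max_ps(tx1, tx2), _mm_max_ps(ty1, ty2)),
        _mm_min_ps(_mm_max_ps(tz1, tz2), _mm_set1_ps(maxT)));

    tNearOut = tNear;
    // One comparison for all four children, collapsed into an integer mask.
    return _mm_movemask_ps(_mm_cmple_ps(tNear, tFar));
}
```

A traversal loop can then count or iterate the set bits of the mask to decide which children to visit, and use the entry distances to choose the order.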


//##########################################################################################
//##########################################################################################
//############		
//############		Fat SIMD Ray Class Declaration
//############		
//##########################################################################################
//##########################################################################################




class RIM_ALIGN(16) AABBTree4:: FatSIMDRay : public math::SIMDRay3D<Float32,4>
{
	public:
		
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Constructors
			
			
			
			
			RIM_INLINE FatSIMDRay( const Ray3f& ray )
				:	math::SIMDRay3D<Float32,4>( ray ),
					inverseDirection( Float32(1) / ray.direction )
			{
				sign[0] = ray.direction.x < Float32(0);
				sign[1] = ray.direction.y < Float32(0);
				sign[2] = ray.direction.z < Float32(0);
			}
			
			
			
			
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Public Data Members
			
			
			
			
			/// The inverse of the direction vector of this SIMD Ray.
			SIMDVector3f inverseDirection;
			
			
			
			
			/// Indices of the sign of the ray's direction along the 3 axes: 0 for positive, 1 for negative.
			/**
			  * The axes are enumerated: 0 = X, 1 = Y, 2 = Z.
			  */
			Index sign[3];
			
			
			
			
};




//##########################################################################################
//##########################################################################################
//############		
//############		Node Class Declaration
//############		
//##########################################################################################
//##########################################################################################




class RIM_ALIGN(128) AABBTree4:: Node
{
	public:
		
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Constructors
			
			
			
			
			/// Create a new inner node with the specified child offsets and child AABBs.
			RIM_FORCE_INLINE Node( IndexType child0, IndexType child1, IndexType child2, IndexType child3,
									const StaticArray<AABB3f,4>& newAABB )
			{
				for ( Index i = 0; i < 4; i++ )
					setChildAABB( i, newAABB[i] );
				
				child[0] = this + child0;
				child[1] = this + child1;
				child[2] = this + child2;
				child[3] = this + child3;
			}
			
			
			
			
			/// Create a new leaf node for the specified primitive offset and primitive count.
			RIM_FORCE_INLINE Node( IndexType primitiveOffset, IndexType primitiveCount )
			{
				indices[0] = 0;
				indices[1] = primitiveOffset;
				indices[2] = primitiveCount;
			}
			
			
			
			
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Leaf Node Attribute Accessor Methods
			
			
			
			
			/// Return whether or not this is a leaf node.
			RIM_FORCE_INLINE Bool isLeaf() const
			{
				return indices[0] == 0;
			}
			
			
			
			
			/// Return the offset in the primitive array of this leaf node's primitives.
			RIM_FORCE_INLINE IndexType getPrimitiveOffset() const
			{
				return indices[1];
			}
			
			
			
			
			/// Set the offset in the primitive array of this leaf node's primitives.
			RIM_FORCE_INLINE void setPrimitiveOffset( IndexType newOffset )
			{
				indices[1] = newOffset;
			}
			
			
			
			
			/// Return the number of primitives that are part of this leaf node.
			RIM_FORCE_INLINE IndexType getPrimitiveCount() const
			{
				return indices[2];
			}
			
			
			
			
			/// Set the number of primitives that are part of this leaf node.
			RIM_FORCE_INLINE void setPrimitiveCount( IndexType newCount )
			{
				indices[2] = newCount;
			}
			
			
			
			
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Child Accessor Methods
			
			
			
			
			/// Return a pointer to the child
			RIM_FORCE_INLINE Node* getChild( Index i )
			{
				return child[i];
			}
			
			
			
			
			RIM_FORCE_INLINE const Node* getChild( Index i ) const
			{
				return child[i];
			}
			
			
			
			
			RIM_FORCE_INLINE void setChildAABB( Index i, const AABB3f& newAABB )
			{
				bounds[0][i] = newAABB.min.x;
				bounds[1][i] = newAABB.max.x;
				bounds[2][i] = newAABB.min.y;
				bounds[3][i] = newAABB.max.y;
				bounds[4][i] = newAABB.min.z;
				bounds[5][i] = newAABB.max.z;
			}
			
			
			
			
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Ray Intersection Methods
			
			
			
			
			RIM_FORCE_INLINE void intersectRay( const FatSIMDRay& ray, SIMDFloat4& near, SIMDFloat4& far ) const
			{
				SIMDFloat4 txmin = (bounds[0 + ray.sign[0]] - ray.origin.x) * ray.inverseDirection.x;
				SIMDFloat4 txmax = (bounds[1 - ray.sign[0]] - ray.origin.x) * ray.inverseDirection.x;
				SIMDFloat4 tymin = (bounds[2 + ray.sign[1]] - ray.origin.y) * ray.inverseDirection.y;
				SIMDFloat4 tymax = (bounds[3 - ray.sign[1]] - ray.origin.y) * ray.inverseDirection.y;
				SIMDFloat4 tzmin = (bounds[4 + ray.sign[2]] - ray.origin.z) * ray.inverseDirection.z;
				SIMDFloat4 tzmax = (bounds[5 - ray.sign[2]] - ray.origin.z) * ray.inverseDirection.z;
				
				const SIMDFloat4 zero( 0.0f );
				const SIMDFloat4 negativeInfinity( math::negativeInfinity<float>() );
				
				near = math::max( math::max( txmin, tymin ), math::max( tzmin, zero ) );
				far = math::max( math::min( math::min( txmax, tymax ), tzmax ), negativeInfinity );
			}
			
			
			
			
		//********************************************************************************
		//********************************************************************************
		//********************************************************************************
		//******	Public Data Members
			
			
			
			/// A set of 4 SIMD axis-aligned bounding boxes for this quad node.
			/**
			  * The bounding boxes are stored in the following format:
			  *	- 0: xMin
			  * - 1: xMax
			  * - 2: yMin
			  * - 3: yMax
			  * - 4: zMin
			  * - 5: zMax
			  */
			SIMDFloat4 bounds[6];
			
			
			
			
			/// The pointers to the 4 child nodes of an inner node, or the attributes of a leaf node.
			/**
			  * For a leaf node, indices[0] is 0, indices[1] is the primitive offset,
			  * and indices[2] is the primitive count.
			  */
			union
			{
				Node* child[4];
				
				IndexType indices[4];
			};
			
			
			
};




//##########################################################################################
//##########################################################################################
//############		
//############		Primitive AABB Class Declaration
//############		
//##########################################################################################
//##########################################################################################




class RIM_ALIGN(16) AABBTree4:: PrimitiveAABB
{
	public:
		
		RIM_FORCE_INLINE PrimitiveAABB( const AABB3f& aabb, Index newPrimitiveIndex )
			:	min( aabb.min ),
				max( aabb.max ),
				primitiveIndex( newPrimitiveIndex )
		{
			centroid = (min + max)*Float(0.5);
		}
		
		
		
		
		/// The minimum coordinate of the primitive's axis-aligned bounding box.
		RIM_ALIGN(16) SIMDFloat4 min;
		
		
		
		
		/// The maximum coordinate of the primitive's axis-aligned bounding box.
		RIM_ALIGN(16) SIMDFloat4 max;
		
		
		
		
		/// The centroid of the primitive's axis-aligned bounding box.
		RIM_ALIGN(16) SIMDFloat4 centroid;
		
		
		
		
		/// The index of this primitive in the primitive set.
		Index primitiveIndex;
		
		
		
};




//##########################################################################################
//##########################################################################################
//############		
//############		Split Bin Class Declaration
//############		
//##########################################################################################
//##########################################################################################




class RIM_ALIGN(16) AABBTree4:: SplitBin
{
	public:
		
		RIM_INLINE SplitBin()
			:	min( math::max<Float32>() ),
				max( math::min<Float32>() ),
				numPrimitives( 0 )
		{
		}
		
		RIM_ALIGN(16) SIMDFloat4 min;
		RIM_ALIGN(16) SIMDFloat4 max;
		
		
		Size numPrimitives;
		
};




//##########################################################################################
//##########################################################################################
//############		
//############		Cached Triangle Class Declaration
//############		
//##########################################################################################
//##########################################################################################




class RIM_ALIGN(16) AABBTree4:: CachedTriangle
{
	public:
		
		RIM_INLINE CachedTriangle( const SIMDVector3f& newV0,
									const SIMDVector3f& newE1,
									const SIMDVector3f& newE2,
									const StaticArray<IndexType,4>& newIndices  )
			:	v0( newV0 ),
				e1( newE1 ),
				e2( newE2 )
		{
			indices[0] = newIndices[0];
			indices[1] = newIndices[1];
			indices[2] = newIndices[2];
			indices[3] = newIndices[3];
		}
		
		/// The vertex of this triangle with index 0.
		SIMDVector3f v0;
		
		/// The edge vector between vertex 0 and vertex 1.
		SIMDVector3f e1;
		
		/// The edge vector between vertex 0 and vertex 2.
		SIMDVector3f e2;
		
		/// The indices of the 4 packed triangles.
		IndexType indices[4];
		
		
};




//##########################################################################################
//##########################################################################################
//############		
//############		Constructors
//############		
//##########################################################################################
//##########################################################################################




AABBTree4:: AABBTree4()
	:	nodes( NULL ),
		primitiveData( NULL ),
		primitiveDataCapacity( 0 ),
		primitiveSet( NULL ),
		cachedPrimitiveType( PrimitiveInterfaceType::UNDEFINED ),
		numPrimitives( 0 ),
		numNodes( 0 ),
		maxDepth( 0 ),
		maxNumPrimitivesPerLeaf( DEFAULT_MAX_PRIMITIVES_PER_LEAF ),
		numSplitCandidates( DEFAULT_NUM_SPLIT_CANDIDATES )
{
}




AABBTree4:: AABBTree4( const Pointer<const PrimitiveInterface>& newPrimitives )
	:	nodes( NULL ),
		primitiveData( NULL ),
		primitiveDataCapacity( 0 ),
		primitiveSet( newPrimitives ),
		cachedPrimitiveType( PrimitiveInterfaceType::UNDEFINED ),
		numPrimitives( 0 ),
		numNodes( 0 ),
		maxDepth( 0 ),
		maxNumPrimitivesPerLeaf( DEFAULT_MAX_PRIMITIVES_PER_LEAF ),
		numSplitCandidates( DEFAULT_NUM_SPLIT_CANDIDATES )
{
}




AABBTree4:: AABBTree4( const AABBTree4& other )
	:	primitiveSet( other.primitiveSet ),
		cachedPrimitiveType( other.cachedPrimitiveType ),
		numPrimitives( other.numPrimitives ),
		numNodes( other.numNodes ),
		maxDepth( other.maxDepth ),
		maxNumPrimitivesPerLeaf( other.maxNumPrimitivesPerLeaf ),
		numSplitCandidates( other.numSplitCandidates )
{
	if ( numNodes > 0 )
		nodes = util::copyArrayAligned( other.nodes, other.numNodes, sizeof(Node) );
	else
		nodes = NULL;
	
	if ( numPrimitives > 0 )
		primitiveData = other.copyPrimitiveData( primitiveDataCapacity );
	else
	{
		primitiveData = NULL;
		primitiveDataCapacity = 0;
	}
}




//##########################################################################################
//##########################################################################################
//############		
//############		Destructor
//############		
//##########################################################################################
//##########################################################################################




AABBTree4:: ~AABBTree4()
{
	if ( nodes )
		util::deallocateAligned( nodes );
	
	if ( primitiveData )
		util::deallocateAligned( primitiveData );
}




//##########################################################################################
//##########################################################################################
//############		
//############		Assignment Operator
//############		
//##########################################################################################
//##########################################################################################





AABBTree4& AABBTree4:: operator = ( const AABBTree4& other )
{
	if ( this != &other )
	{
		if ( numNodes < other.numNodes )
		{
			if ( nodes )
				util::deallocateAligned( nodes );
			
			nodes = util::copyArrayAligned( other.nodes, other.numNodes, sizeof(Node) );
		}
		else if ( other.numNodes > 0 )
			rim::util::copy( nodes, other.nodes, other.numNodes );
		
		if ( primitiveData )
			util::deallocateAligned( primitiveData );
		
		if ( other.numPrimitives > 0 )
			primitiveData = other.copyPrimitiveData( primitiveDataCapacity );
		else
		{
			primitiveData = NULL;
			primitiveDataCapacity = 0;
		}
		
		primitiveSet = other.primitiveSet;
		cachedPrimitiveType = other.cachedPrimitiveType;
		numPrimitives = other.numPrimitives;
		numNodes = other.numNodes;
		maxDepth = other.maxDepth;
		maxNumPrimitivesPerLeaf = other.maxNumPrimitivesPerLeaf;
		numSplitCandidates = other.numSplitCandidates;
	}
	
	return *this;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Primitive Accessor Methods
//############		
//##########################################################################################
//##########################################################################################




void AABBTree4:: setPrimitives( const Pointer<const PrimitiveInterface>& newPrimitives )
{
	primitiveSet = newPrimitives;
	
	// Set the number of nodes and primitives to 0 to signal that the BVH needs to be rebuilt.
	numNodes = 0;
	numPrimitives = 0;
}




const Pointer<const PrimitiveInterface>& AABBTree4:: getPrimitives() const
{
	return primitiveSet;
}




//##########################################################################################
//##########################################################################################
//############		
//############		BVH Attribute Accessor Methods
//############		
//##########################################################################################
//##########################################################################################




Size AABBTree4:: getMaxDepth() const
{
	return maxDepth;
}





Bool AABBTree4:: isValid() const
{
	return numNodes > 0;
}




//##########################################################################################
//##########################################################################################
//############		
//############		BVH Building Methods
//############		
//##########################################################################################
//##########################################################################################




void AABBTree4:: rebuild()
{
	// Don't build the tree if there is no primitive set (check before dereferencing it).
	if ( primitiveSet.isNull() )
		return;
	
	const Size newNumPrimitives = primitiveSet->getSize();
	
	// Don't build the tree if there are no primitives.
	if ( newNumPrimitives == 0 )
		return;
	
	//**************************************************************************************
	
	// Allocate an array to hold the list of PrimitiveAABB primitives.
	PrimitiveAABB* primitiveAABBs = rim::util::allocateAligned<PrimitiveAABB>( newNumPrimitives, 16 );
	
	// Initialize all PrimitiveAABB primitives with the primitives for this tree.
	for ( Index i = 0; i < newNumPrimitives; i++ )
		new (primitiveAABBs + i) PrimitiveAABB( primitiveSet->getAABB(i), i );
	
	//**************************************************************************************
	
	Size numSplitBins = numSplitCandidates + 1;
	
	// Allocate a temporary array to hold the split bins.
	SplitBin* splitBins = rim::util::allocateAligned<SplitBin>( numSplitBins, 16 );
	
	//**************************************************************************************
	
	// Compute the number of nodes needed for this tree.
	Size newNumNodes = newNumPrimitives*Size(2) - 1;
	
	// Allocate space for the nodes in this tree.
	if ( newNumNodes > numNodes )
	{
		if ( nodes )
			rim::util::deallocateAligned( nodes );
		
		nodes = rim::util::allocateAligned<Node>( newNumNodes, sizeof(Node) );
	}
	
	numNodes = newNumNodes;
	
	// Build the tree, starting with the root node.
	buildTreeRecursive( nodes, primitiveAABBs, 0, newNumPrimitives,
						splitBins, numSplitBins, maxNumPrimitivesPerLeaf, 1, maxDepth );
	
	//**************************************************************************************
	
	// Determine if the BVH should cache the primitives based on their type.
	numPrimitives = newNumPrimitives;
	Size newPrimitiveDataSize = 0;
	
	switch ( primitiveSet->getType() )
	{
		case PrimitiveInterfaceType::TRIANGLES:
			newPrimitiveDataSize = getTriangleArraySize( nodes )*sizeof(CachedTriangle);
			break;
		default:
			newPrimitiveDataSize = numPrimitives*sizeof(Index);
			break;
	}
	
	// Allocate an array to hold the primitive data.
	if ( newPrimitiveDataSize > primitiveDataCapacity )
	{
		if ( primitiveData )
			rim::util::deallocateAligned( primitiveData );
		
		primitiveData = rim::util::allocateAligned<UByte>( newPrimitiveDataSize, 16 );
		primitiveDataCapacity = newPrimitiveDataSize;
	}
	
	// Copy the current order of the PrimitiveAABB list into the tree's list of primitive pointers.
	switch ( primitiveSet->getType() )
	{
		case PrimitiveInterfaceType::TRIANGLES:
			fillTriangleArray( (CachedTriangle*)primitiveData, primitiveSet, primitiveAABBs, nodes, 0 );
			cachedPrimitiveType = PrimitiveInterfaceType::TRIANGLES;
			break;
			
		default:
			fillPrimitiveIndices( (Index*)primitiveData, primitiveAABBs, numPrimitives );
			cachedPrimitiveType = PrimitiveInterfaceType::UNDEFINED;
			break;
	}
	
	//**************************************************************************************
	// Clean up the temporary arrays of PrimitiveAABB primitives and split bins.
	
	rim::util::deallocateAligned( primitiveAABBs );
	rim::util::deallocateAligned( splitBins );
}




void AABBTree4:: refit()
{
	if ( numNodes == 0 )
		return;
	
	// If the number or type of primitives has changed, rebuild the tree instead.
	if ( numPrimitives != primitiveSet->getSize() || cachedPrimitiveType != primitiveSet->getType() )
	{
		this->rebuild();
		return;
	}
	
	// Refit the tree for different kinds of primitives.
	switch ( cachedPrimitiveType )
	{
		case PrimitiveInterfaceType::TRIANGLES:
			this->refitTreeTriangles( nodes );
			break;
		default:
			this->refitTreeGeneric( nodes );
	}
}




//##########################################################################################
//##########################################################################################
//############		
//############		Ray Tracing Methods
//############		
//##########################################################################################
//##########################################################################################




static unsigned int bitCount( unsigned int mask )
{
#if defined(RIM_COMPILER_GCC)
	return __builtin_popcount( mask );
#elif defined(RIM_COMPILER_MSVC)
	return __popcnt( mask );
#else
	mask = mask - ((mask >> 1) & 0x55555555);
	mask = (mask & 0x33333333) + ((mask >> 2) & 0x33333333);
	return (((mask + (mask >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
#endif
}




static unsigned long firstSetBit( unsigned long mask )
{
#if defined(RIM_COMPILER_GCC)
	// Use the 'long' variant so that the upper bits of a 64-bit mask are not truncated.
	return __builtin_ctzl( mask );
#elif defined(RIM_COMPILER_MSVC)
	unsigned long index;
	_BitScanForward( &index, mask );
	return index;
#else
	#error "firstSetBit() is not implemented for this compiler"
#endif
}




static Int minIndex( const SIMDFloat4& x )
{
	const SIMDInt4 indices1( 0, 1, 2, 3 );
	const SIMDInt4 indices2( 2, 3, 0, 1 );
	
	// Shuffle the value once to find the minimum of 0 & 2, 1 & 3.
	SIMDFloat4 x2 = math::shuffle<2,3,0,1>( x );
	
	// Determine the indices of the values which are the minimum of 0 & 2, 1 & 3.
	SIMDInt4 indices3 = math::select( x < x2, indices1, indices2 );
	
	// Find the minimum of 0 & 2, 1 & 3.
	x2 = math::min( x, x2 );
	
	// Shuffle the values again to determine the minimum value.
	SIMDFloat4 x3 = math::shuffle<1,0,3,2>( x2 );
	
	// Compute the index of the closest intersection.
	SIMDInt4 minimumIndex = math::select( x2 < x3, indices3, math::shuffle<1,0,3,2>( indices3 ) );
	
	return minimumIndex[0];
}




static Int minIndex( const SIMDFloat4& x, SIMDFloat4& wideMin )
{
	const SIMDInt4 indices1( 0, 1, 2, 3 );
	const SIMDInt4 indices2( 2, 3, 0, 1 );
	
	// Shuffle the value once to find the minimum of 0 & 2, 1 & 3.
	SIMDFloat4 x2 = math::shuffle<2,3,0,1>( x );
	
	// Determine the indices of the values which are the minimum of 0 & 2, 1 & 3.
	SIMDInt4 indices3 = math::select( x < x2, indices1, indices2 );
	
	// Find the minimum of 0 & 2, 1 & 3.
	x2 = math::min( x, x2 );
	
	// Shuffle the values again to determine the minimum value.
	SIMDFloat4 x3 = math::shuffle<1,0,3,2>( x2 );
	
	// Compute the index of the closest intersection.
	SIMDInt4 minimumIndex = math::select( x2 < x3, indices3, math::shuffle<1,0,3,2>( indices3 ) );
	
	// Compute a 4-wide vector of the minimum value.
	wideMin = math::min( x2, x3 );
	
	return minimumIndex[0];
}




Bool AABBTree4:: traceRay( const Ray3f& newRay, Float maxDistance, TraversalStack& traversalStack,
							Float& closestIntersection, Index& closestPrimitive ) const
{
	if ( numNodes == 0 )
		return false;
	
	switch ( cachedPrimitiveType )
	{
		case PrimitiveInterfaceType::TRIANGLES:
			return this->traceRayVsTriangles( newRay, maxDistance, traversalStack, closestIntersection, closestPrimitive );
		
		default:
			return this->traceRayVsGeneric( newRay, maxDistance, traversalStack, closestIntersection, closestPrimitive );
	}
	
	return false;
}




Bool AABBTree4:: traceRay( const Ray3f& ray, Float maxDistance, TraversalStack& stack ) const
{
	Float d;
	Index primitiveIndex;
	
	return traceRay( ray, maxDistance, stack, d, primitiveIndex );
}




//##########################################################################################
//##########################################################################################
//############		
//############		Generic Ray Tracing Method
//############		
//##########################################################################################
//##########################################################################################




Bool AABBTree4:: traceRayVsGeneric( const Ray3f& newRay, Float maxDistance, TraversalStack& traversalStack,
									Float& closestIntersection, Index& closestPrimitive ) const
{
	const void** stackBase = traversalStack.getRoot();
	const void** stack = stackBase + 1;
	*stack = nodes;
	
	const PrimitiveInterface* const primitives = primitiveSet;
	const Index* const indices = (const Index*)primitiveData;
	
	FatSIMDRay ray( newRay );
	closestIntersection = maxDistance;
	SIMDFloat4 closestDistance( maxDistance );
	SIMDFloat4 near;
	SIMDFloat4 far;
	Float primitiveDistance;
	Index closestPrimitiveIndex;
	const Node* node = nodes;
	
	while ( true )
	{
		nextNode:
		
		if ( node->isLeaf() )
		{
			if ( primitives->intersectRay( indices + node->getPrimitiveOffset(),
											node->getPrimitiveCount(), newRay,
											primitiveDistance, closestPrimitiveIndex ) &&
				primitiveDistance < closestIntersection )
			{
				closestIntersection = primitiveDistance;
				closestDistance = SIMDFloat4( primitiveDistance );
				closestPrimitive = closestPrimitiveIndex;
			}
		}
		else
		{
			node->intersectRay( ray, near, far );
			
			SIMDBool4 intersectionResult = (near <= far) & (near < closestDistance);
			Int mask = intersectionResult.getMask();
			
			switch ( mask )
			{
				// No hits. Backtrack on the stack.
				case 0:
					break;
				
				// 1 Hit. Replace the current node with the hit child.
				case 1 << 0:	node = node->getChild(0);	goto nextNode;
				case 1 << 1:	node = node->getChild(1);	goto nextNode;
				case 1 << 2:	node = node->getChild(2);	goto nextNode;
				case 1 << 3:	node = node->getChild(3);	goto nextNode;
				
				// 2 Hits.
				case 3: // 0011
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[1] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(1);
					}
					else
					{
						*stack = node->getChild(1);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 5: // 0101
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[2] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(2);
					}
					else
					{
						*stack = node->getChild(2);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 6: // 0110
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[2] < near[1] )
					{
						*stack = node->getChild(1);
						node = node->getChild(2);
					}
					else
					{
						*stack = node->getChild(2);
						node = node->getChild(1);
					}
					goto nextNode;
				}
				
				case 9: // 1001
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 10: // 1010
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[1] )
					{
						*stack = node->getChild(1);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(1);
					}
					goto nextNode;
				}
				
				case 12: // 1100
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[2] )
					{
						*stack = node->getChild(2);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(2);
					}
					goto nextNode;
				}
				
				
				default:
				{
					// There are more than 2 hit children.
					// Determine the index of the closest hit child.
					Int closestChildIndex = minIndex( math::select( intersectionResult, near, closestDistance ) );
					
					// Clear the bit of the closest hit child.
					mask &= ~(1 << closestChildIndex);
					
					//****************************************************
					// Second hit.
					
					Int i = firstSetBit( mask );
					
					// Put the child onto the stack.
					stack++;
					*stack = node->getChild(i);
					
					// Clear the bit.
					mask &= ~(1 << i);
					
					//****************************************************
					// Third hit.
					
					i = firstSetBit( mask );
					
					// Put the child onto the stack.
					stack++;
					*stack = node->getChild(i);
					
					// Clear the bit.
					mask &= ~(1 << i);
					
					//****************************************************
					// Fourth hit, if necessary.
					
					if ( mask )
					{
						i = firstSetBit( mask );
						
						// Put the child onto the stack.
						stack++;
						*stack = node->getChild(i);
						
						// Clear the bit.
						mask &= ~(1 << i);
					}
					
					// Determine the next node to traverse.
					node = node->getChild(closestChildIndex);
					goto nextNode;
				}
			}
		}
		
		node = (const Node*)*stack;
		stack--;
		
		if ( stack == stackBase )
			break;
	}
	
	// If the distance is less than the maximum distance which we started with, there was an intersection.
	return closestIntersection < maxDistance;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Triangle Ray Tracing Method
//############		
//##########################################################################################
//##########################################################################################




Bool AABBTree4:: traceRayVsTriangles( const Ray3f& newRay, Float maxDistance, TraversalStack& traversalStack,
									Float& closestIntersection, Index& closestPrimitive ) const
{
	const void** stackBase = traversalStack.getRoot();
	const void** stack = stackBase + 1;
	*stack = nodes;
	
	const CachedTriangle* const triangles = (const CachedTriangle*)primitiveData;
	
	FatSIMDRay ray( newRay );
	closestIntersection = maxDistance;
	SIMDFloat4 closestDistance( maxDistance );
	SIMDFloat4 near;
	SIMDFloat4 far;
	SIMDFloat4 triangleDistance;
	const Node* node = nodes;
	
	while ( true )
	{
		nextNode:
		
		if ( node->isLeaf() )
		{
			const IndexType numNodePrimitives = node->getPrimitiveCount();
			
			if ( numNodePrimitives == 1 )
			{
				// Fast case for a leaf holding a single packet of 4 triangles.
				const CachedTriangle* triangle = triangles + node->getPrimitiveOffset();
				
				// Find the intersections.
				SIMDBool4 triangleMask = rayIntersectsTriangles( ray, *triangle, triangleDistance );
				triangleMask &= (triangleDistance < closestDistance);
				
				// Find the closest intersection index if there was an intersection.
				if ( triangleMask.getMask() )
				{
					// Set all non-intersecting triangles to have a very large distance
					// so that they won't affect the closest intersection computation.
					triangleDistance = math::select( triangleMask, triangleDistance, SIMDFloat4(math::max<Float>()) );
					
					Int minTIndex = minIndex( triangleDistance, closestDistance );
					
					closestPrimitive = triangle->indices[minTIndex];
					closestIntersection = closestDistance[minTIndex];
				}
			}
			else
			{
				// General case for many triangles.
				const CachedTriangle* triangle = triangles + node->getPrimitiveOffset();
				const CachedTriangle* const trianglesEnd = triangle + numNodePrimitives;
				
				while ( triangle != trianglesEnd )
				{
					// Compute the intersection distance for all 4 triangles.
					SIMDBool4 triangleMask = rayIntersectsTriangles( ray, *triangle, triangleDistance );
					triangleMask &= (triangleDistance < closestDistance);
					
					// Find the closest intersection index if there was an intersection.
					if ( triangleMask.getMask() )
					{
						// Set all non-intersecting triangles to have a very large distance
						// so that they won't affect the closest intersection computation.
						triangleDistance = math::select( triangleMask, triangleDistance, SIMDFloat4(math::max<Float>()) );
						
						Int minTIndex = minIndex( triangleDistance, closestDistance );
						
						closestPrimitive = triangle->indices[minTIndex];
						closestIntersection = closestDistance[minTIndex];
					}
					
					triangle++;
				}
			}
		}
		else
		{
			node->intersectRay( ray, near, far );
			
			SIMDBool4 intersectionResult = (near <= far) & (near < closestDistance);
			Int mask = intersectionResult.getMask();
			
			switch ( mask )
			{
				// No hits. Backtrack on the stack.
				case 0:
					break;
				
				// 1 Hit. Replace the current node with the hit child.
				case 1 << 0:	node = node->getChild(0);	goto nextNode;
				case 1 << 1:	node = node->getChild(1);	goto nextNode;
				case 1 << 2:	node = node->getChild(2);	goto nextNode;
				case 1 << 3:	node = node->getChild(3);	goto nextNode;
				
				// 2 Hits.
				case 3: // 0011
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[1] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(1);
					}
					else
					{
						*stack = node->getChild(1);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 5: // 0101
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[2] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(2);
					}
					else
					{
						*stack = node->getChild(2);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 6: // 0110
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[2] < near[1] )
					{
						*stack = node->getChild(1);
						node = node->getChild(2);
					}
					else
					{
						*stack = node->getChild(2);
						node = node->getChild(1);
					}
					goto nextNode;
				}
				
				case 9: // 1001
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[0] )
					{
						*stack = node->getChild(0);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(0);
					}
					goto nextNode;
				}
				
				case 10: // 1010
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[1] )
					{
						*stack = node->getChild(1);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(1);
					}
					goto nextNode;
				}
				
				case 12: // 1100
				{
					near = math::select( intersectionResult, near, closestDistance );
					stack++;
					
					if ( near[3] < near[2] )
					{
						*stack = node->getChild(2);
						node = node->getChild(3);
					}
					else
					{
						*stack = node->getChild(3);
						node = node->getChild(2);
					}
					goto nextNode;
				}
				
				
				default:
				{
					// There are more than 2 hit children.
					// Determine the index of the closest hit child.
					Int closestChildIndex = minIndex( math::select( intersectionResult, near, closestDistance ) );
					
					// Clear the bit of the closest hit child.
					mask &= ~(1 << closestChildIndex);
					
					//****************************************************
					// Second hit.
					
					Int i = firstSetBit( mask );
					
					// Put the child onto the stack.
					stack++;
					*stack = node->getChild(i);
					
					// Clear the bit.
					mask &= ~(1 << i);
					
					//****************************************************
					// Third hit.
					
					i = firstSetBit( mask );
					
					// Put the child onto the stack.
					stack++;
					*stack = node->getChild(i);
					
					// Clear the bit.
					mask &= ~(1 << i);
					
					//****************************************************
					// Fourth hit, if necessary.
					
					if ( mask )
					{
						i = firstSetBit( mask );
						
						// Put the child onto the stack.
						stack++;
						*stack = node->getChild(i);
						
						// Clear the bit.
						mask &= ~(1 << i);
					}
					
					// Determine the next node to traverse.
					node = node->getChild(closestChildIndex);
					goto nextNode;
				}
			}
		}
		
		node = (const Node*)*stack;
		stack--;
		
		if ( stack == stackBase )
			break;
	}
	
	// If the distance is less than the maximum distance which we started with, there was an intersection.
	return closestIntersection < maxDistance;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Ray Vs. Triangle Intersection Method
//############		
//##########################################################################################
//##########################################################################################




SIMDBool4 AABBTree4:: rayIntersectsTriangles( const SIMDRay3f& ray, const CachedTriangle& triangle, SIMDFloat4& distance )
{
	// the vector perpendicular to edge 2 and the ray's direction
	SIMDVector3f pvec = math::cross( ray.direction, triangle.e2 );
	SIMDFloat4 det = math::dot( triangle.e1, pvec );
	
	// Do the first rejection test for the triangles: reject rays that are (nearly) parallel to the triangle's plane.
	SIMDBool4 result = math::abs(det) >= math::epsilon<Float>();
	
	//************************************************************************************
	
	SIMDFloat4 inverseDet = Float(1) / det;
	
	SIMDVector3f v0ToSource = ray.origin - triangle.v0;
	
	SIMDFloat4 u = math::dot( v0ToSource, pvec ) * inverseDet;
	
	// Do the second rejection test for the triangles. See if the UV coordinate is within the valid range.
	result &= (u >= Float(0)) & (u <= Float(1));
	
	//************************************************************************************
	
	SIMDVector3f qvec = math::cross( v0ToSource, triangle.e1 );
	
	SIMDFloat4 v = math::dot( ray.direction, qvec ) * inverseDet;
	
	// Do the third rejection test for the triangles. See if the UV coordinate is within the valid range.
	result &= (v >= Float(0)) & (u + v <= Float(1));
	
	//************************************************************************************
	
	distance = math::dot( triangle.e2, qvec ) * inverseDet;
	
	// Make sure that the triangles are hit by the forward side of the ray.
	return result & (distance > math::epsilon<Float>());
}




//##########################################################################################
//##########################################################################################
//############		
//############		Recursive Tree Construction Method
//############		
//##########################################################################################
//##########################################################################################




Size AABBTree4:: buildTreeRecursive( Node* node, PrimitiveAABB* primitiveAABBs,
										Index start, Size numPrimitives,
										SplitBin* splitBins, Size numSplitBins, 
										Size maxNumPrimitivesPerLeaf, Size depth, Size& maxDepth )
{
	// The split axis used for each split (0 = X, 1 = Y, 2 = Z).
	StaticArray<Index,3> splitAxis;
	
	// The number of primitives in a child node (leaf or not).
	StaticArray<Index,4> numChildPrimitives;
	
	// The 4 volumes of the child nodes.
	StaticArray<AABB3f,4> volumes;
	
	//***************************************************************************
	// Partition the set of primitives into two sets.
	
	PrimitiveAABB* const primitiveAABBStart = primitiveAABBs + start;
	Size numLesserPrimitives = 0;
	
	partitionPrimitivesSAH( primitiveAABBStart, numPrimitives,
						splitBins, numSplitBins,
						splitAxis[0], numLesserPrimitives, volumes[0], volumes[2] );
	
	// Compute the number of primitives greater than the split plane along the split axis.
	Size numGreaterPrimitives = numPrimitives - numLesserPrimitives;
	
	//***************************************************************************
	// Partition the primitive subsets into four sets based on the next two splitting planes.
	
	// If the number of primitives on the lesser side of the first partition is less than the max number of
	// primitives per leaf, put all the primitives in the first child.
	if ( numLesserPrimitives < maxNumPrimitivesPerLeaf )
	{
		numChildPrimitives[0] = numLesserPrimitives;
		numChildPrimitives[1] = 0;
		volumes[0] = computeAABBForPrimitives( primitiveAABBStart, numLesserPrimitives );
	}
	else
	{
		partitionPrimitivesSAH( primitiveAABBStart, numLesserPrimitives,
							splitBins, numSplitBins,
							splitAxis[1], numChildPrimitives[0], volumes[0], volumes[1] );
	}
	
	// If the number of primitives on the greater side of the first partition is less than the max number of
	// primitives per leaf, put all the primitives in the third child.
	if ( numGreaterPrimitives < maxNumPrimitivesPerLeaf )
	{
		numChildPrimitives[2] = numGreaterPrimitives;
		numChildPrimitives[3] = 0;
		volumes[2] = computeAABBForPrimitives( primitiveAABBStart + numLesserPrimitives, numGreaterPrimitives );
	}
	else
	{
		partitionPrimitivesSAH( primitiveAABBStart + numLesserPrimitives, numGreaterPrimitives,
							splitBins, numSplitBins,
							splitAxis[2], numChildPrimitives[2], volumes[2], volumes[3] );
	}
	
	// Compute the number of primitives on the greater side of each second-level split.
	numChildPrimitives[1] = numLesserPrimitives - numChildPrimitives[0];
	numChildPrimitives[3] = numGreaterPrimitives - numChildPrimitives[2];
	
	//***************************************************************************
	// Determine for each child whether to create a leaf node or an inner node.
	
	// The 4 indices of either the location in the primitive list of a leaf's primitives,
	// or the relative offset of the child node from the parent.
	StaticArray<Index,4> indices;
	
	//***************************************************************************
	// Determine the type and attributes for each node.
	
	// Keep track of the total number of nodes in the subtree.
	Size numTreeNodes = 1;
	Size primitiveStartIndex = start;
	
	for ( Index i = 0; i < 4; i++ )
	{
		if ( numChildPrimitives[i] <= maxNumPrimitivesPerLeaf )
		{
			// This child is a leaf node.
			new (node + numTreeNodes) Node( primitiveStartIndex, numChildPrimitives[i] );
			indices[i] = numTreeNodes;
			
			numTreeNodes++;
		}
		else
		{
			// This child is an inner node, construct it recursively.
			Size numChildNodes = buildTreeRecursive( node + numTreeNodes, primitiveAABBs,
													primitiveStartIndex, numChildPrimitives[i],
													splitBins, numSplitBins, maxNumPrimitivesPerLeaf,
													depth + 1, maxDepth );
			
			// The relative index of this child from the parent node.
			indices[i] = numTreeNodes;
			
			numTreeNodes += numChildNodes;
		}
		
		primitiveStartIndex += numChildPrimitives[i];
	}
	
	//***************************************************************************
	// Create the node.
	
	new (node) Node( indices[0], indices[1], indices[2], indices[3], volumes );
	
	// Update the maximum tree depth.
	if ( depth > maxDepth )
		maxDepth = depth;
	
	// Return the number of nodes in this subtree.
	return numTreeNodes;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Surface Area Heuristic Object Partition Method
//############		
//##########################################################################################
//##########################################################################################




void AABBTree4:: partitionPrimitivesSAH( PrimitiveAABB* primitiveAABBs, Size numPrimitives,
											SplitBin* splitBins, Size numSplitBins,
											Index& splitAxis, Size& numLesserPrimitives,
											AABB3f& lesserVolume, AABB3f& greaterVolume )
{
	// If there are fewer than two primitives, there is nothing to partition, so return immediately.
	if ( numPrimitives < 2 )
	{
		splitAxis = 0;
		numLesserPrimitives = numPrimitives;
		lesserVolume = computeAABBForPrimitives( primitiveAABBs, numPrimitives );
		return;
	}
	
	//**************************************************************************************
	// Compute the AABB of the primitive centroids.
	
	// We use the centroids as the 'keys' in splitting primitives.
	const AABB3f centroidAABB = computeAABBForPrimitiveCentroids( primitiveAABBs, numPrimitives );
	const Vector3f aabbDimension = centroidAABB.max - centroidAABB.min;
	
	//**************************************************************************************
	// Initialize the split bins.
	
	const Size numSplitCandidates = numSplitBins - 1;
	
	const Float binningConstant1 = Float(numSplitBins)*(Float(1) - Float(0.00001));
	Float minSplitCost = math::max<Float>();
	Float minSplitPlane = 0;
	math::SIMDScalar<float,4> lesserMin;
	math::SIMDScalar<float,4> lesserMax;
	math::SIMDScalar<float,4> greaterMin;
	math::SIMDScalar<float,4> greaterMax;
	numLesserPrimitives = 0;
	
	splitAxis = 0;
	
	for ( Index axis = 0; axis < 3; axis++ )
	{
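		// Skip degenerate axes where all centroids coincide, since the binning
		// constant below would otherwise divide by zero.
		if ( aabbDimension[axis] <= Float(0) )
			continue;
		
		// Compute some constants that are valid for all bins/primitives.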
		const Float binningConstant = binningConstant1 / aabbDimension[axis];
		const Float binWidth = aabbDimension[axis] / Float(numSplitBins);
		const Float binsStart = centroidAABB.min[axis];
		
		// Initialize the split bins to their starting values.
		for ( Index i = 0; i < numSplitBins; i++ )
			new (splitBins + i) SplitBin();
		
		//**************************************************************************************
		// For each primitive, determine which bin it overlaps and increase that bin's counter.
		
		for ( Index i = 0; i < numPrimitives; i++ )
		{
			const PrimitiveAABB& t = primitiveAABBs[i];
			
			Index binIndex = (Index)(binningConstant*(t.centroid[axis] - binsStart));
			SplitBin& bin = splitBins[binIndex];
			
			// Update the number of primitives that this bin contains, as well as the AABB for those primitives.
			bin.numPrimitives++;
			bin.min = math::min( bin.min, t.min );
			bin.max = math::max( bin.max, t.max );
		}
		
		//**************************************************************************************
		// Find the split plane with the smallest SAH cost.
		
		Size numLeftPrimitives = 0;
		math::SIMDScalar<float,4> leftMin( math::max<float>() );
		math::SIMDScalar<float,4> leftMax( math::min<float>() );
		
		for ( Index i = 0; i < numSplitCandidates; i++ )
		{
			// Since the left candidate is only growing, we can incrementally construct the AABB for this side.
			// Incrementally enlarge the bounding box for this side, and compute the number of primitives
			// on this side of the split.
			{
				SplitBin& bin = splitBins[i];
				numLeftPrimitives += bin.numPrimitives;
				leftMin = math::min( leftMin, bin.min );
				leftMax = math::max( leftMax, bin.max );
			}
			
			Size numRightPrimitives = 0;
			math::SIMDScalar<float,4> rightMin( math::max<float>() );
			math::SIMDScalar<float,4> rightMax( math::min<float>() );
			
			// Compute the bounding box for this side, and compute the number of primitives
			// on this side of the split.
			for ( Index j = i + 1; j < numSplitBins; j++ )
			{
				SplitBin& bin = splitBins[j];
				numRightPrimitives += bin.numPrimitives;
				rightMin = math::min( rightMin, bin.min );
				rightMax = math::max( rightMax, bin.max );
			}
			
			// Compute the cost for this split candidate.
			Float splitCost = Float(numLeftPrimitives)*getAABBSurfaceArea( leftMin, leftMax ) + 
							Float(numRightPrimitives)*getAABBSurfaceArea( rightMin, rightMax );
			
			// If the split cost is the lowest so far, use it as the new minimum split.
			if ( splitCost <= minSplitCost )
			{
				minSplitCost = splitCost;
				minSplitPlane = binsStart + binWidth*Float(i + 1);
				
				// Save the bounding boxes for this split candidate.
				lesserMin = leftMin;
				lesserMax = leftMax;
				greaterMin = rightMin;
				greaterMax = rightMax;
				
				// Save the number of primitives to the left of the split.
				numLesserPrimitives = numLeftPrimitives;
				
				// Save the axis of the minimum cost split candidate.
				splitAxis = axis;
			}
		}
	}
	
	//**************************************************************************************
	
	// If the split was unsuccessful, try a median split which is guaranteed to split the primitives.
	if ( numLesserPrimitives == 0 || numLesserPrimitives == numPrimitives )
	{
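		// Choose to split along the axis with the largest extent.
		// Assign to the output parameter rather than declaring a local that shadows it,
		// so that the caller sees the axis that was actually used.
		splitAxis = aabbDimension[0] > aabbDimension[1] ? 
						aabbDimension[0] > aabbDimension[2] ? 0 : 2 :
						aabbDimension[1] > aabbDimension[2] ? 1 : 2;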
		
		// Use a median-based partition to split the primitives.
		partitionPrimitivesMedian( primitiveAABBs, numPrimitives, splitAxis, numLesserPrimitives, lesserVolume, greaterVolume );
		
		return;
	}
	
	//**************************************************************************************
	// Partition the primitives into two sets based on the minimal cost split plane.
	
	Index left = 0;
	Index right = numPrimitives - 1;
	
	while ( left < right )
	{
		// Move right while primitive < split plane.
		while ( primitiveAABBs[left].centroid[splitAxis] <= minSplitPlane && left < right )
			left++;
		
		// Move left while primitive > split plane.
		while ( primitiveAABBs[right].centroid[splitAxis] > minSplitPlane && left < right )
			right--;
		
		if ( left < right )
		{
			// Swap the primitives because they are out of order.
			const PrimitiveAABB temp = primitiveAABBs[left];
			primitiveAABBs[left] = primitiveAABBs[right];
			primitiveAABBs[right] = temp;
		}
	}
	
	// Output the bounding volumes of the primitives on each side of the split plane.
	lesserVolume = AABB3f( lesserMin[0], lesserMax[0], lesserMin[1], lesserMax[1], lesserMin[2], lesserMax[2] );
	greaterVolume = AABB3f( greaterMin[0], greaterMax[0], greaterMin[1], greaterMax[1], greaterMin[2], greaterMax[2] );
}




//##########################################################################################
//##########################################################################################
//############		
//############		Median Object Partition Method
//############		
//##########################################################################################
//##########################################################################################




void AABBTree4:: partitionPrimitivesMedian( PrimitiveAABB* primitiveAABBs, Size numPrimitives,
												Index splitAxis, Size& numLesserPrimitives,
												AABB3f& lesserVolume, AABB3f& greaterVolume )
{
	if ( numPrimitives == 2 )
	{
		numLesserPrimitives = 1;
		lesserVolume = computeAABBForPrimitives( primitiveAABBs, 1 );
		greaterVolume = computeAABBForPrimitives( primitiveAABBs + 1, 1 );
		return;
	}
	
	Index first = 0;
	Index last = numPrimitives - 1;
	Index middle = (first + last)/2;
	
	while ( 1 )
	{
		Index mid = first;
		const math::SIMDScalar<float,4>& key = primitiveAABBs[mid].centroid;
		
		for ( Index j = first + 1; j <= last; j ++)
		{
			if ( primitiveAABBs[j].centroid[splitAxis] > key[splitAxis] )
			{
				mid++;
				
				// interchange values.
				const PrimitiveAABB temp = primitiveAABBs[mid];
				primitiveAABBs[mid] = primitiveAABBs[j];
				primitiveAABBs[j] = temp;
			}
		}
		
		// interchange the first and mid value.
		const PrimitiveAABB temp = primitiveAABBs[mid];
		primitiveAABBs[mid] = primitiveAABBs[first];
		primitiveAABBs[first] = temp;
		
		if ( mid + 1 == middle )
			break;
		
		if ( mid + 1 > middle )
			last = mid - 1;
		else
			first = mid + 1;
	}
	
	numLesserPrimitives = numPrimitives / 2;
	
	lesserVolume = computeAABBForPrimitives( primitiveAABBs, numLesserPrimitives );
	greaterVolume = computeAABBForPrimitives( primitiveAABBs + numLesserPrimitives, numPrimitives - numLesserPrimitives );
}




//##########################################################################################
//##########################################################################################
//############		
//############		Generic Tree Refit Method
//############		
//##########################################################################################
//##########################################################################################




AABB3f AABBTree4:: refitTreeGeneric( Node* node )
{
	if ( node->isLeaf() )
	{
		if ( node->getPrimitiveCount() == 0 )
			return AABB3f();
		
		// Compute the bounding box of this leaf's primitives.
		const Index* primitive = (const Index*)primitiveData + node->getPrimitiveOffset();
		const Size primitiveCount = node->getPrimitiveCount();
		
		AABB3f result = primitiveSet->getAABB( primitive[0] );
		
		for ( Index i = 1; i < primitiveCount; i++ )
			result.enlargeFor( primitiveSet->getAABB( primitive[i] ) );
		
		return result;
	}
	else
	{
		AABB3f result;
		
		// Recursively find the new bounding box for the children of this node.
		for ( Index i = 0; i < 4; i++ )
		{
			AABB3f childAABB = refitTreeGeneric( node->getChild(i) );
			
			// Store the bounding box for the child in this node.
			node->setChildAABB( i, childAABB );
			
			// Find the bounding box containing all children.
			if ( i == 0 )
				result = childAABB;
			else
				result.enlargeFor( childAABB );
		}
		
		return result;
	}
}




//##########################################################################################
//##########################################################################################
//############		
//############		Triangle Tree Refit Method
//############		
//##########################################################################################
//##########################################################################################




AABB3f AABBTree4:: refitTreeTriangles( Node* node )
{
	if ( node->isLeaf() )
	{
		if ( node->getPrimitiveCount() == 0 )
			return AABB3f();
		
		// Compute the bounding box of this leaf's primitives.
		CachedTriangle* triangle = (CachedTriangle*)primitiveData + node->getPrimitiveOffset();
		const Size primitiveCount = node->getPrimitiveCount();
		
		AABB3f result = primitiveSet->getAABB( triangle[0].indices[0] );
		Vector3f v0, v1, v2;
		
		for ( Index i = 0; i < primitiveCount; i++ )
		{
			for ( Index j = 0; j < 4; j++ )
			{
				result.enlargeFor( primitiveSet->getAABB( triangle[i].indices[j] ) );
				
				// Update cached triangles.
				primitiveSet->getTriangle( triangle[i].indices[j], v0, v1, v2 );
				Vector3f e1 = v1 - v0;
				Vector3f e2 = v2 - v0;
				
				triangle[i].v0.x[j] = v0.x;
				triangle[i].v0.y[j] = v0.y;
				triangle[i].v0.z[j] = v0.z;
				triangle[i].e1.x[j] = e1.x;
				triangle[i].e1.y[j] = e1.y;
				triangle[i].e1.z[j] = e1.z;
				triangle[i].e2.x[j] = e2.x;
				triangle[i].e2.y[j] = e2.y;
				triangle[i].e2.z[j] = e2.z;
			}
		}
		
		return result;
	}
	else
	{
		AABB3f result;
		
		// Recursively find the new bounding box for the children of this node.
		for ( Index i = 0; i < 4; i++ )
		{
			AABB3f childAABB = refitTreeTriangles( node->getChild(i) );
			
			// Store the bounding box for the child in this node.
			node->setChildAABB( i, childAABB );
			
			// Find the bounding box containing all children.
			if ( i == 0 )
				result = childAABB;
			else
				result.enlargeFor( childAABB );
		}
		
		return result;
	}
}




//##########################################################################################
//##########################################################################################
//############		
//############		Primitive Index List Building Method
//############		
//##########################################################################################
//##########################################################################################




void AABBTree4:: fillPrimitiveIndices( Index* primitiveIndices, const PrimitiveAABB* primitiveAABBs, Size numPrimitives )
{
	for ( Index i = 0; i < numPrimitives; i++ )
		primitiveIndices[i] = primitiveAABBs[i].primitiveIndex;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Axis-Aligned Bound Box Calculation Methods
//############		
//##########################################################################################
//##########################################################################################




AABB3f AABBTree4:: computeAABBForPrimitives( const PrimitiveAABB* primitiveAABBs, Size numPrimitives )
{
	/// Create a bounding box with the minimum at the max float value and vice versa.
	SIMDFloat4 min( math::max<float>() );
	SIMDFloat4 max( math::min<float>() );
	
	const PrimitiveAABB* const primitiveAABBsEnd = primitiveAABBs + numPrimitives;
	
	while ( primitiveAABBs != primitiveAABBsEnd )
	{
		min = math::min( min, primitiveAABBs->min );
		max = math::max( max, primitiveAABBs->max );
		
		primitiveAABBs++;
	}
	
	return AABB3f( min[0], max[0], min[1], max[1], min[2], max[2] );
}




AABB3f AABBTree4:: computeAABBForPrimitiveCentroids( const PrimitiveAABB* primitiveAABBs, Size numPrimitives )
{
	/// Create a bounding box with the minimum at the max float value and vice versa.
	SIMDFloat4 min( math::max<float>() );
	SIMDFloat4 max( math::min<float>() );
	
	const PrimitiveAABB* const primitiveAABBsEnd = primitiveAABBs + numPrimitives;
	
	while ( primitiveAABBs != primitiveAABBsEnd )
	{
		min = math::min( min, primitiveAABBs->centroid );
		max = math::max( max, primitiveAABBs->centroid );
		
		primitiveAABBs++;
	}
	
	return AABB3f( min[0], max[0], min[1], max[1], min[2], max[2] );
}




float AABBTree4:: getAABBSurfaceArea( const SIMDFloat4& min,
										const SIMDFloat4& max )
{
	const SIMDFloat4 aabbDimension = max - min;
	
	return float(2)*(aabbDimension[0]*aabbDimension[1] +
					aabbDimension[0]*aabbDimension[2] +
					aabbDimension[1]*aabbDimension[2]);
}




//##########################################################################################
//##########################################################################################
//############		
//############		Triangle List Building Methods
//############		
//##########################################################################################
//##########################################################################################




Size AABBTree4:: getTriangleArraySize( const Node* node ) const
{
	if ( node->isLeaf() )
		return math::nextMultiple( node->getPrimitiveCount(), IndexType(4) ) >> 2;
	else
	{
		Size result = 0;
		
		for ( Index i = 0; i < 4; i++ )
			result += getTriangleArraySize( node->getChild(i) );
		
		return result;
	}
}




Size AABBTree4:: fillTriangleArray( CachedTriangle* triangles, const PrimitiveInterface* primitiveInterface,
									const PrimitiveAABB* aabbs, Node* node, Size numFilled )
{
	Size currentOutputIndex = numFilled;
	
	if ( node->isLeaf() )
	{
		Size numLeafTriangles = node->getPrimitiveCount();
		Size numTruncatedTriangles = ((numLeafTriangles >> 2) << 2);
		Size numPaddedTriangles = numTruncatedTriangles == numLeafTriangles ? 
									numTruncatedTriangles : numTruncatedTriangles + 4;
		
		// Update the per-node primitive count to reflect that 4 regular triangles = 1 cached triangle.
		Index currentOffset = node->getPrimitiveOffset();
		node->setPrimitiveOffset( currentOutputIndex );
		node->setPrimitiveCount( numPaddedTriangles >> 2 );
		
		Size numIterations = numPaddedTriangles >> 2;
		
		for ( Index k = 0; k < numIterations; k++ )
		{
			// Determine the number of triangles to go into this cached triangle, 4 or less.
			Size numRemainingTriangles = math::min( numLeafTriangles - k*4, Size(4) );
			
			Vector3f v0, v1, v2;
			SIMDVector3f simdV0;
			SIMDVector3f simdE1;
			SIMDVector3f simdE2;
			StaticArray<Index,4> indices;
			
			// Get the triangle from the primitive set.
			for ( Index t = 0; t < 4; t++ )
			{
				// If there are no more remaining triangles, use the last valid one.
				if ( t < numRemainingTriangles )
					indices[t] = aabbs[currentOffset + t].primitiveIndex;
				else
					indices[t] = aabbs[currentOffset + numRemainingTriangles - 1].primitiveIndex;
				
				primitiveInterface->getTriangle( indices[t], v0, v1, v2 );
				Vector3f e1 = v1 - v0;
				Vector3f e2 = v2 - v0;
				
				// Convert to SIMD layout.
				simdV0.x[t] = v0.x;
				simdV0.y[t] = v0.y;
				simdV0.z[t] = v0.z;
				simdE1.x[t] = e1.x;
				simdE1.y[t] = e1.y;
				simdE1.z[t] = e1.z;
				simdE2.x[t] = e2.x;
				simdE2.y[t] = e2.y;
				simdE2.z[t] = e2.z;
			}
			
			// Create the new triangle.
			new (triangles + currentOutputIndex) CachedTriangle( simdV0, simdE1, simdE2, indices );
			
			currentOffset += 4;
			currentOutputIndex++;
		}
	}
	else
	{
		for ( Index i = 0; i < 4; i++ )
			currentOutputIndex += fillTriangleArray( triangles, primitiveInterface, aabbs, node->getChild(i), currentOutputIndex );
	}
	
	return currentOutputIndex - numFilled;
}




//##########################################################################################
//##########################################################################################
//############		
//############		Primitive Data Copy Method
//############		
//##########################################################################################
//##########################################################################################




UByte* AABBTree4:: copyPrimitiveData( Size& newCapacity ) const
{
	switch ( cachedPrimitiveType )
	{
		case PrimitiveInterfaceType::TRIANGLES:
			return (UByte*)util::copyArrayAligned( (const CachedTriangle*)primitiveData, getTriangleArraySize(nodes), 16 );
		
		default:
			return (UByte*)util::copyArrayAligned( (const Index*)primitiveData, numPrimitives, 16 );
	}
}

Nobody said it, so I'm gonna say it.

USE __RESTRICT
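
To make that concrete, here is a minimal sketch of what the qualifier looks like in practice. The function and parameter names are hypothetical and not taken from the code above. `__restrict` is a compiler extension (GCC, Clang and MSVC accept this spelling), not standard C++. It promises the compiler that the pointers never alias, so it is free to keep values in registers and vectorize the loop instead of reloading after every store.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: out[i] = a[i] - b[i], e.g. one component of an edge vector.
// The __restrict qualifiers are a promise by the caller that 'out' does not
// overlap 'a' or 'b'. Breaking that promise is undefined behaviour.
void subtractArrays( float* __restrict out,
                     const float* __restrict a,
                     const float* __restrict b,
                     std::size_t count )
{
	// Debug-only check for the most obvious violation of the no-alias contract.
	assert( out != a && out != b );
	
	for ( std::size_t i = 0; i < count; i++ )
		out[i] = a[i] - b[i];
}
```

Whether this helps at all depends on the compiler and on whether aliasing was actually blocking an optimization, so it is worth checking the generated assembly before and after.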

Also, your code has an if( determinant .... ); which sounds to me like it could be made branchless using conditional moves, unless the cost of executing the code that can be skipped is considerably larger than the cost of mispredicting the branch.
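
A rough scalar sketch of the idea, under assumptions: the original determinant test is not shown here, so the function name, the epsilon and the "return infinity on a miss" convention are all made up for illustration. The point is only the shape: always do the work, then select the result, rather than jumping over the work.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical stand-in for the tail of a ray/triangle test.
// Returns the hit distance, or +infinity if the determinant is too small.
float selectHitDistance( float determinant, float tNumerator )
{
	const float epsilon = 1e-6f; // assumed threshold, tune for your data
	const float miss = std::numeric_limits<float>::infinity();
	
	const bool valid = std::fabs( determinant ) > epsilon;
	
	// Always perform the division, using a harmless divisor when invalid.
	const float safeDeterminant = valid ? determinant : 1.0f;
	const float t = tNumerator / safeDeterminant;
	
	// Simple selects like these are usually compiled to conditional moves
	// or blends, but that is up to the compiler and not guaranteed.
	return valid ? t : miss;
}
```

For example, selectHitDistance( 2.0f, 4.0f ) gives 2.0f, while a zero determinant gives infinity. In a 4-wide SIMD version the same pattern becomes a compare that produces a mask, followed by a blend.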

@Aressera

Woah, that AABB code looks great, thank you! This is the main area now slowing mine down, as it is a rather naive octree test system and still has to test a lot of triangles.

I don't suppose I could see your equivalent of a Vector4 implementation? I see you are accessing the components directly (.x, .y etc.), so are you using a union? I have to access my __m128 as .m128_f32[3] (for x, as it annoyingly reverses the order in which it stores the floats), which doesn't help with code portability, but I was unsure of the impact of using a union to access the elements.

@Matias

That link looks interesting. I think there are a fair number of mispredictions and load stalls potentially happening in my code, so I'll check it out :)

I don't use SIMD for my 4D vector, mostly for alignment reasons. I differentiate between a SIMD scalar (4 floats, i.e. 4 X-components) and cartesian vectors. Everything is in structures-of-arrays (SoA) format, so there shouldn't be any issues with the storage order for __m128, no need to know what is X and what is Y, etc.

In my code:

SIMDFloat4 = { union { float x[4]; __m128 v; } }

SIMDVector3 = { SIMDFloat4 x; SIMDFloat4 y; SIMDFloat4 z; }

SIMDRay3 = { SIMDVector3 origin; SIMDVector3 direction; }
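
Spelled out as compilable C++, that layout might look roughly like the sketch below. Only the type names and the union come from the shorthand above; the constructors, operator[] and operator- are assumptions added so the example does something. Note that writing one union member and reading the other is not sanctioned by the C++ standard, although GCC, Clang and MSVC all support it for this kind of use.

```cpp
#include <cassert>
#include <xmmintrin.h>

// Four floats that are lanes of one SIMD register (e.g. 4 X-components).
struct SIMDFloat4
{
	union { float x[4]; __m128 v; };
	
	SIMDFloat4() : v( _mm_setzero_ps() ) {}
	explicit SIMDFloat4( float scalar ) : v( _mm_set1_ps( scalar ) ) {}
	explicit SIMDFloat4( __m128 value ) : v( value ) {}
	
	// Lane access in memory order: lane 0 is x[0], no reversal to remember.
	float& operator [] ( int i ) { return x[i]; }
	float operator [] ( int i ) const { return x[i]; }
};

inline SIMDFloat4 operator - ( const SIMDFloat4& a, const SIMDFloat4& b )
{
	return SIMDFloat4( _mm_sub_ps( a.v, b.v ) );
}

// Structure-of-arrays: 4 vectors stored as 4 X's, 4 Y's, 4 Z's.
struct SIMDVector3 { SIMDFloat4 x; SIMDFloat4 y; SIMDFloat4 z; };

struct SIMDRay3 { SIMDVector3 origin; SIMDVector3 direction; };
```

With this, filling lane t of a packed vertex is just simdV0.x[t] = v0.x, as in fillTriangleArray above, and one subtraction on a SIMDFloat4 handles the same component of four triangles at once.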
