From what I gather, the simplest improvement works like this. Instead of splitting at the centre of the node's AABB along the current axis, project each triangle onto that axis to get its extent [tmin, tmax], which gives two candidate splitting points per triangle. For each candidate point, count the triangles that lie to its left and to its right. Once every triangle's candidates have been evaluated, choose the splitting point that divides the triangles most evenly between the two sides. Repeat this process for every node.
Here is the k-d tree code (cleanup code omitted):
// Packing helpers for KdTree::Node::_Data, a 32-bit `uint`.
// The low two bits select the node kind: 0..2 = split axis of an interior node,
// 3 = leaf. The remaining bits hold either the split value's float bit pattern
// (its two lowest mantissa bits are overwritten by the axis) or a TriangleList
// pointer, which must therefore be at least 4-byte aligned.
// Every argument is parenthesized so expression arguments expand correctly.
// NOTE(review): the `*(uint*)&v` / `*(float*)&x` reinterpretation violates strict
// aliasing, and ENCODE_LEAF keeps only sizeof(uint) bytes of a pointer, which
// truncates pointers on 64-bit targets.

// des <- bits of float v with the low two bits replaced by axis d.
#define ENCODE_VALUE_AND_DIMENSION( des, v, d ) ( (des) = ( *(uint*)&(v) & ( UINT_MAX << 2 ) ) | ( (d) & 3u ) )
// des <- bits of pointer v tagged with the leaf marker 3.
#define ENCODE_LEAF( des, v ) ( (des) = ( *(uint*)&(v) & ( UINT_MAX << 2 ) ) | 3u )
// Number of nodes in a complete binary tree with levels 0..i, i.e. 2^(i+1) - 1.
#define TREE_SIZE( i ) ( ( 1 << ( (i) + 1 ) ) - 1 )
// Node kind / split axis stored in the low two bits.
#define DECODE_DIMENSION( x ) ( (x) & 3 )
// Split value as a float. The axis bits are NOT masked off, so this differs from
// the encoded float by a few ULPs.
#define DECODE_VALUE( x ) ( *(float*)&(x) )
#define ISLEAF( x ) ( DECODE_DIMENSION( x ) == 3 )
// Leaf payload (pointer bits) with the tag bits cleared.
#define DECODE_LEAF( x ) ( (x) & ( UINT_MAX << 2 ) )
// Child offsets relative to the parent node at level c of a tree whose deepest
// level is m. The left child is stored immediately after its parent. The right
// child follows the parent's entire left subtree of TREE_SIZE(m - c - 1) nodes.
#define LEFT_NODE( c, m ) ( 1 )
#define RIGHT_NODE( c, m ) ( 1 + ( TREE_SIZE( (m) - (c) ) >> 1 ) )
class KdTree
{
public:
KdTree( uint MaximumLevel, uint MinimumTriangles, const TriangleList& triangles )
: m_MaximumLevel(MaximumLevel), m_MinimumTriangles(MinimumTriangles)
{
uint size = TREE_SIZE( m_MaximumLevel );
m_pNodeArray = new Node[size];
memset( m_pNodeArray, 0, size * sizeof(Node));
for( TriangleList::size_type i = 0; i < triangles.size(); ++i )
{
m_AABB.AddPoint(triangles->v0.p);
m_AABB.AddPoint(triangles->v1.p);
m_AABB.AddPoint(triangles->v2.p);
}
m_pNodeArray->Add( m_AABB, 0, triangles, 0, m_MaximumLevel, m_MinimumTriangles );
}
bool Intersect( const Ray& queryRay, IntersectionInfo& ii )
{
float tmin, tmax, iit = ii.t; int bailout = 0;
if( RayAABBIntersection( queryRay, m_AABB, tmin, tmax ) )
{
Vec3 invRayDir = Vec3( 1.0f / queryRay.dir.x, 1.0f / queryRay.dir.y , 1.0f / queryRay.dir.z );
m_pNodeArray->Intersect( bailout, 0, m_MaximumLevel, queryRay.Split(tmin), invRayDir, tmin, tmax - tmin, ii );
return ii.t < iit;
}
return 0;
}
private:
struct Node
{
Node(){ _Data = 0; }
void Add( const AABB& aabb, uint d, const TriangleList& triangles, uint CurrentLevel, uint MaximumLevel, uint MinimumTriangles )
{
if( CurrentLevel == MaximumLevel || triangles.size() < MinimumTriangles )
{
TriangleList* pTriangles = 0;
if( !triangles.empty() )
{
pTriangles = new TriangleList( triangles.begin(), triangles.end() );
}
ENCODE_LEAF( _Data, pTriangles );
}
else
{
float splitValue = 0.5f * aabb.minima[d] + 0.5f * aabb.maxima[d];
TriangleList leftTBuffer, rightTBuffer;
leftTBuffer.reserve( triangles.size() );
rightTBuffer.reserve( triangles.size() );
for( TriangleList::size_type i = 0; i < triangles.size(); ++i )
{
if( triangles->v0.p[d] <= splitValue ||
triangles->v1.p[d] <= splitValue ||
triangles->v2.p[d] <= splitValue )
{
leftTBuffer.push_back(triangles);
}
if( triangles->v0.p[d] > splitValue ||
triangles->v1.p[d] > splitValue ||
triangles->v2.p[d] > splitValue )
{
rightTBuffer.push_back(triangles);
}
}
ENCODE_VALUE_AND_DIMENSION( _Data, splitValue, d );
(this + LEFT_NODE( CurrentLevel, MaximumLevel ) )->Add( aabb.ClipByPlane( 0, d, splitValue ), (d+1)%3, leftTBuffer, CurrentLevel+1, MaximumLevel, MinimumTriangles );
(this + RIGHT_NODE( CurrentLevel, MaximumLevel ) )->Add( aabb.ClipByPlane( 1, d, splitValue ), (d+1)%3, rightTBuffer, CurrentLevel+1, MaximumLevel, MinimumTriangles );
}
}
void IntersectLeaf( const TriangleList* pTriangles, int& bailout, const Ray& ray, const float tail, const float tmax, IntersectionInfo& ii, int& TriangleTests )
{
float lu, lv, lw, lt;
for( TriangleList::size_type i = 0; i < pTriangles->size(); ++i )
{
const Triangle* pT = (*pTriangles);
if( RayTriangleIntersectionDouble( ray, pT->v0.p, pT->v1.p, pT->v2.p, lu, lv, lw, lt ) )
{
float t = lt + tail;
if( t < ii.t )
{
ii.t = t;
if( ii.type == ANY )
{
bailout = 1;
return;
}
if( lt < tmax )
{
bailout = 1;
}
ii.pM = pT->v0.pMaterial;
ii.UV = lu*pT->v0.uv + lv*pT->v1.uv + lw*pT->v2.uv;
ii.GNormal = lu*pT->v0.n + lv*pT->v1.n + lw*pT->v2.n;
ii.SNormal = lu*pT->v0.n + lv*pT->v1.n + lw*pT->v2.n;
}
}
}
}
void Intersect( int& bailout, uint CurrentLevel, uint MaximumLevel, const Ray& ray, const Vec3& invRayDir, const float tail, const float tmax, IntersectionInfo& ii )
{
if( tail > ii.t )
return;
if( ISLEAF( _Data ) )
{
if( TriangleList* pTriangles = (TriangleList*)DECODE_LEAF( _Data ) )
IntersectLeaf( pTriangles, bailout, ray, tail, tmax, ii, TriangleTests );
}
else
{
uint splitDimension = DECODE_DIMENSION( _Data );
float splitValue = DECODE_VALUE( _Data );
Node* first = this + LEFT_NODE( CurrentLevel, MaximumLevel );
Node* second = this + RIGHT_NODE( CurrentLevel, MaximumLevel );
if( ray.origin[splitDimension] > splitValue )
{
std::swap( first, second );
}
float tplane = ( splitValue - ray.origin[splitDimension] ) * invRayDir[splitDimension];
if( tplane >= 0.0f && tplane < tmax )
{
if( !bailout ) first->Intersect( bailout, CurrentLevel+1, MaximumLevel, ray, invRayDir, tail, tplane, ii );
if( !bailout ) second->Intersect( bailout, CurrentLevel+1, MaximumLevel, ray.Split(tplane), invRayDir, tail + tplane, tmax - tplane, ii );
}
else
{
if( !bailout ) first->Intersect( bailout, CurrentLevel+1, MaximumLevel, ray, invRayDir, tail, tmax, ii );
}
}
}
uint _Data;
};
Node* m_pNodeArray;
uint m_MaximumLevel, m_MinimumTriangles;
AABB m_AABB;
};