#ifndef __BVH_TREE_H_INCLUDED
#define __BVH_TREE_H_INCLUDED

#include "Box.hxx"

class BVH
{
	struct Node
	{
		Box bbox;
		//Note: the result is stored in ray! (in ray.t and ray.hit)
		virtual void traverse(Ray &ray) = 0;
	};

	struct InnerNode : public Node
	{
		Node *leftChild, *rightChild;

		InnerNode(const Box& aBBox)
		{
			bbox = aBBox;
		}

		virtual void traverse(Ray &ray)
		{
			float leftMinInt = leftChild->bbox.Intersect(ray).first;
			float rightMinInt = rightChild->bbox.Intersect(ray).first;

			if(leftMinInt < rightMinInt)
			{
				leftChild->traverse(ray);
				if(ray.t < rightMinInt)
					return;
				rightChild->traverse(ray);
			}
			else if(rightMinInt < FLT_MAX)
			{
				rightChild->traverse(ray);
				if(ray.t < leftMinInt)
					return;
				leftChild->traverse(ray);
			}
		}
		
		virtual ~InnerNode() {};
	};

	struct LeafNode : public Node
	{
		vector<Primitive *> primitive;

		LeafNode(const Box& aBBox, vector<Primitive *> &prim)
		{
			bbox = aBBox;
			primitive = prim;
		}

		virtual void traverse(Ray &ray)
		{
			for (int i=0; i<(int)primitive.size(); i++)
				primitive[i]->Intersect(ray);
		}

		virtual ~LeafNode() {};
	};

public:
	int maxDepth, minTri;

	Node *BuildTree(Box &bounds, vector<Primitive *> prim, int depth)
	{
		if (depth > maxDepth || (int)prim.size() <= minTri) 
		{
			// could do some optimizations here..
			return new LeafNode(bounds, prim);
		}

		InnerNode *node = new InnerNode(bounds);

		Vec3f diam = bounds.max - bounds.min;
		int dim = diam.MaxDim();
		Vec3f center = (bounds.max + bounds.min) * 0.5;

		Box lBounds, rBounds;
		vector<Primitive *> lPrim, rPrim;

		for (int i = 0; i < (int)prim.size(); i++) 
		{
			Box primBox = prim[i]->CalcBounds();
			Vec3f primCenter = (primBox.min + primBox.max) * 0.5;
			if(primCenter[dim] <= center[dim])
			{
				lPrim.push_back(prim[i]);
				lBounds.Extend(primBox);
			}
			else
			{
				rPrim.push_back(prim[i]);
				rBounds.Extend(primBox);
			}
		}

		node->leftChild = BuildTree(lBounds,lPrim,depth+1);
		node->rightChild = BuildTree(rBounds,rPrim,depth+1);

		return node;
	}

	Node *root;

	BVH(Box topBox, vector<Primitive *> primitives)
	{
		maxDepth = 30;
		minTri = 3;
		root = NULL;
		root = BuildTree(topBox, primitives, 0);
	}

	bool Intersect(Ray &ray) 
	{
		root->traverse(ray);

		return ray.hit != NULL;
	}

};

#endif
