#ifndef __BVH_TREE_H_INCLUDED
#define __BVH_TREE_H_INCLUDED

#include "Box.hxx"

class BVH
{
	struct Node
	{
		Box bbox;

		//Note: the result is stored in ray! (in ray.t and ray.hit)
		virtual void traverse(Ray &ray) = 0;

		virtual ~Node(){};
	};

	struct InnerNode : public Node
	{
		Node *leftChild, *rightChild;

		InnerNode(const Box& aBBox)
		{
			bbox = aBBox;
		}

		virtual void traverse(Ray &ray)
		{
			std::pair<float, float> leftPair  = leftChild->bbox.Intersect(ray),
			                        rightPair = rightChild->bbox.Intersect(ray);

			if (leftPair.first < rightPair.first)
			{
				// Left is near
				if (leftPair.first < FLT_MAX)
					leftChild->traverse(ray);

				if (rightPair.first < FLT_MAX && (!ray.hit || ray.t > rightPair.first))
					rightChild->traverse(ray);
			}
			else
			{
				// Right is near
				if (rightPair.first < FLT_MAX)
					rightChild->traverse(ray);

				if (leftPair.first < FLT_MAX && (!ray.hit || ray.t > leftPair.first))
					leftChild->traverse(ray);
			}
		}

		virtual ~InnerNode() {}
	};

	struct LeafNode : public Node
	{
		vector<Primitive *> primitive;

		LeafNode(const Box& aBBox, vector<Primitive *> &prim)
		{
			bbox = aBBox;
			primitive = prim;
		}

		virtual void traverse(Ray &ray)
		{
			for (vector<Primitive *>::iterator it = primitive.begin(); it != primitive.end(); ++it)
				(*it)->Intersect(ray);
		}

		virtual ~LeafNode() {}
	};

public:
	int maxDepth, minTri;

	// In this exercise, we use the median instead of halving the distance in each step, because
	// this balances the tree further, which increases rendering speed (as the bounding boxes are
	// shrinked to minimal size, which eliminates most of the problems described in lecture)
	Node *BuildTree(Box &bounds, vector<Primitive *> prim, int depth = 0)
	{
		if (static_cast<int>(prim.size()) <= minTri || depth >= maxDepth)
		{
			// Generate a voxel
			return new LeafNode(bounds, prim);
		}
		else
		{
			// Generate inner node
			Vec3f sum(0);
			Vec3f minimum(Infinity),
			      maximum(-Infinity),
			      elongation;
			Box   box;

			InnerNode *thisnode = new InnerNode(bounds);

			for (vector<Primitive *>::iterator it = prim.begin(); it != prim.end(); ++it)
			{
				box  = (*it)->CalcBounds();

				// Median calculation
				sum += box.max + box.min;

				// Investgation of greatest elongation
				minimum.SetMin(box.min);
				maximum.SetMax(box.max);
			}

			Vec3f median = sum / static_cast<float>(prim.size());

			vector<Primitive *> leftPrim,
					    rightPrim;

			Box leftBox, rightBox;

			// Determine elongation
			elongation = maximum - minimum;

			if (elongation.x > elongation.y && elongation.x > elongation.z)
			{
				// Divide in x
				for (vector<Primitive *>::iterator it = prim.begin(); it != prim.end(); ++it)
				{
					box  = (*it)->CalcBounds();
					if (box.max.x + box.min.x < median.x)
					{
						leftPrim.push_back(*it);
						leftBox.Extend(box.max);
						leftBox.Extend(box.min);
					}
					else
					{
						rightPrim.push_back(*it);
						rightBox.Extend(box.max);
						rightBox.Extend(box.min);
					}
				}
			}
			else if (elongation.y > elongation.x && elongation.y > elongation.z)
			{
				// Divide in y
				for (vector<Primitive *>::iterator it = prim.begin(); it != prim.end(); ++it)
				{
					box  = (*it)->CalcBounds();
					if (box.max.y + box.min.y < median.y)
					{
						leftPrim.push_back(*it);
						leftBox.Extend(box.max);
						leftBox.Extend(box.min);
					}
					else
					{
						rightPrim.push_back(*it);
						rightBox.Extend(box.max);
						rightBox.Extend(box.min);
					}
				}
			}
			else
			{
				// Divide in z
				for (vector<Primitive *>::iterator it = prim.begin(); it != prim.end(); ++it)
				{
					box  = (*it)->CalcBounds();
					if (box.max.z + box.min.z < median.z)
					{
						leftPrim.push_back(*it);
						leftBox.Extend(box.max);
						leftBox.Extend(box.min);
					}
					else
					{
						rightPrim.push_back(*it);
						rightBox.Extend(box.max);
						rightBox.Extend(box.min);
					}
				}
			}
			thisnode->leftChild  = BuildTree(leftBox,  leftPrim,  depth + 1);
			thisnode->rightChild = BuildTree(rightBox, rightPrim, depth + 1);
			return thisnode;
		}
	}

	Node *root;

	BVH(Box topBox, vector<Primitive *> primitives)
	{
		maxDepth = 30;
		minTri   = 3;
		root     = BuildTree(topBox, primitives, 0);
	}

	virtual ~BVH(){};

	bool Intersect(Ray &ray) 
	{
		ray.hit = NULL;
		ray.t   = Infinity;

		root->traverse(ray);

		return ray.hit != NULL;
	}

};

#endif
