#ifndef SCHLICKSHADER_HXX
#define SCHLICKSHADER_HXX

#include "Shader.hxx"
#include "Material.hxx"

#include "defines.h"
 
/// Shader implementing Schlick's approximation of the Cook-Torrance reflection model
/// (with an optional second surface layer), combined with mirror reflection and
/// refraction that are evaluated by recursively tracing secondary rays through the scene.
class SchlickShader : public Shader
{
public:
	Vec3f      ambient,           // ambient radiance: material->ambient * Product(material->color, La), precomputed once
	           light_ray_org_std; // hit point of the ray being shaded; template for light_ray.org, to calculate it only once
	float      p12,               // square of the isotropy factor p1 of the primary surface layer
	           p22;               // square of the isotropy factor p2 of the secondary surface layer

	/// @brief Creates the shader and precomputes material-dependent constants.
	/// @param scene    scene used for shadow tests and for tracing secondary rays
	/// @param material surface parameters (color, ambient, per-layer c/r/p, mirror, transparency, filters, ...)
	SchlickShader(Scene *scene, Material *material)
		: Shader(scene, material)
	{
		// Ambient illumination part
		ambient = material->ambient * Product(material->color, La);

		// Sanity check of the reflection and refraction fractions.
		// NOTE(review): this streams the Material pointer itself; unless an operator<< for
		// Material* exists elsewhere, the message shows an address rather than a name.
		if (material->mirror > 1.0 || material->transparency > 1.0 || material->mirror + material->transparency > 1.0)
			cerr << "WARNING: Material '" << material << "' has wrong reflection/refraction percentages (> 1.0)!" << endl;

		// Precomputation of p^2 for both layers
		p12 = material->p1 * material->p1;
		p22 = material->p2 * material->p2;
	}

	virtual ~SchlickShader()
	{}

	/// @brief Traces the perfect mirror reflection of @p origray at its hit point.
	/// @param origray incoming ray; its hit point is origray.org + origray.t * origray.dir
	/// @param normal  surface normal at the hit point (assumed unit length; its sign does
	///                not affect the reflected direction)
	/// @return radiance returned by scene->RayTrace() for the reflected ray
	virtual Vec3f MirroredColor(const Ray &origray, const Vec3f &normal)
	{
		Ray ray;
		ray.org          = origray.org + origray.t * origray.dir;
		ray.t            = Infinity;
		ray.hit          = NULL;
		// A reflected ray stays in the medium the incoming ray travels through.
		ray.refrac_index = origray.refrac_index;

		// Project the reversed ray direction onto the normal...
		Vec3f projected = Dot(normal, -origray.dir) * normal;

		// ...and mirror the incoming direction: r = d + 2 * ((-d . n) n).
		// The reflection has to start from origray.dir; a fresh Ray's dir is not a valid base.
		ray.dir = origray.dir + 2 * projected;
		Normalize(ray.dir);

		return scene->RayTrace(ray);
	}

	/// @brief Traces the refracted continuation of @p origray at its hit point (Snell's law).
	/// @param origray incoming ray (direction assumed unit length)
	/// @param normal  unit surface normal at the hit point; Shade() passes it oriented
	///                against the incoming ray direction
	/// @param into    true if the ray enters the material, false if it leaves it
	/// @return radiance along the refracted ray, or the mirrored radiance when total
	///         internal reflection leaves no refracted ray
	virtual Vec3f TransparentColor(const Ray &origray, const Vec3f &normal, const bool &into)
	{
		Ray ray;
		ray.org = origray.org + origray.t * origray.dir;
		ray.t   = Infinity;
		ray.hit = NULL;

		// Entering adds the material's index excess over 1.0, leaving removes it again
		// (presumably so that nested media are tracked along the ray).
		if (into)
			ray.refrac_index = origray.refrac_index + (material->refrac_index - 1.0f);
		else
			ray.refrac_index = origray.refrac_index - (material->refrac_index - 1.0f);

		// Unit tangent: the part of the incoming direction orthogonal to the normal,
		// -((d x n) x n) = d - (d . n) n. It vanishes at normal incidence (d parallel to n),
		// where normalizing it would produce NaNs, so a zero tangent is used there instead.
		Vec3f side = Cross(origray.dir, normal);
		Vec3f tangent;
		if (Dot(side, side) < Epsilon * Epsilon)
			tangent = Vec3f(0.0f);
		else
		{
			tangent = -Cross(side, normal);
			Normalize(tangent);
		}

		// Snell's law: n1 * sin(a) = n2 * sin(b), with sin(a) = d . tangent.
		float sinb  = origray.refrac_index / ray.refrac_index * Dot(origray.dir, tangent);
		float cosb2 = 1.0f - sinb * sinb;

		// Total internal reflection: no refracted ray exists, all light is mirrored.
		if (cosb2 < 0.0f)
			return MirroredColor(origray, normal);

		ray.dir = sinb * tangent - sqrt(cosb2) * normal;

		return scene->RayTrace(ray);
	}

	/// @brief Computes the radiance leaving the hit point of @p ray towards its origin.
	///
	/// Sums the Schlick BRDF response to every (non-occluded) light, scales it down by the
	/// mirror/transparency fractions, adds the filtered mirrored/refracted radiance and
	/// finally the precomputed ambient term.
	/// @param ray primary or secondary ray with a valid hit (ray.hit, ray.t)
	/// @return shaded radiance
	virtual Vec3f Shade(Ray &ray)
	{
		Vec3f N    = ray.hit->GetNormal(ray);
		bool  into = true;

		// Orient the normal against the incoming ray; if it had to be flipped, the ray
		// arrives from the back side and is leaving the material.
		if (Dot(N, ray.dir) > 0)
		{
			N    = -N;
			into = false;
		}

		light_ray_org_std = ray.org + ray.t * ray.dir;

		if (material->bumpmapper)
			material->bumpmapper->BumpMap(N, light_ray_org_std);

		// Surface color: material color, modulated by the texel at the hit point. It does
		// not depend on the light source, so it is looked up once instead of per light.
		Vec3f albedo = material->color;
		if (material->texture)
		{
			float tu = 0,
			      tv = 0;

			ray.hit->GetUV(ray, tu, tv);
			// v is inverted before the lookup (presumably GetUV() and the texture use
			// opposite vertical conventions).
			tv = 1.0f - tv;

			albedo = Product(albedo, material->texture->GetTexel(tu, tv));
		}

		// Light-independent view quantities: the view direction projected onto the tangent
		// plane (reference direction for the azimuth angle) and the cosine at the viewer.
		Vec3f T         = ray.dir - Dot(ray.dir, N) * N;
		bool  hasViewTangent = Dot(T, T) >= Epsilon * Epsilon;
		if (hasViewTangent)
			Normalize(T);
		float x = fabs(Dot(-ray.dir, N));

		Vec3f Lr(0, 0, 0),
		      Ll;
		Ray   light_ray;

		for (std::vector<Light*>::iterator it = scene->lights.begin(); it != scene->lights.end(); ++it)
		{
#ifdef AREALIGHT
			for (int i = 0; i < NUM_AREA_SAMPLES; ++i)
#endif
			{
				light_ray.org = light_ray_org_std;
				(*it)->Illuminate(light_ray, Ll);

				// occolor is handed to the shadow test and afterwards filters the light
				// color (presumably attenuated by partially transparent occluders).
				Vec3f occolor = (*it)->intensity.NormTo(1.0);
				// 'continue' skips to the next sample (AREALIGHT) or to the next light.
				if (scene->Occluded(light_ray, occolor))
					continue;

				// Halfway vector in between the direction to the light and to the viewer
				Vec3f H = light_ray.dir - ray.dir;
				Normalize(H);

				// Cosines of various angles between the vectors involved.
				// NOTE(review): fabs() means lights behind the surface (relative to N) are
				// not rejected here; only the shadow test excludes them — confirm intent.
				float t = fabs(Dot(H,             N));
				float u = fabs(Dot(light_ray.dir, H));
				float v = fabs(Dot(light_ray.dir, N));

				// Squared cosine of the azimuth angle between H projected onto the tangent
				// plane and T. It is undefined if either projection vanishes; the layers then
				// use the isotropic azimuth factor instead of normalizing a zero vector.
				Vec3f Hb         = H - Dot(H, N) * N;
				bool  hasAzimuth = hasViewTangent && Dot(Hb, Hb) >= Epsilon * Epsilon;
				float w2         = 0.0f;
				if (hasAzimuth)
				{
					Normalize(Hb);
					float w = Dot(Hb, T);
					w2 = w * w;
				}

				// Schlick approximation of the Cook-Torrance illumination model
				float fresnel = powf(1 - u, 5);

				// Single layered part
				float s1    = material->c1 + (1 - material->c1) * fresnel;
				float brdf  = s1 * DirectionalFactor(material->r1, material->p1, p12, t, v, x, w2, hasAzimuth);

				// Optional double layered part
				if (material->c2 != UNDEF)
				{
					float s2 = material->c2 + (1 - material->c2) * fresnel;
					brdf += (1 - s1) * s2 * DirectionalFactor(material->r2, material->p2, p22, t, v, x, w2, hasAzimuth);
				}

				Vec3f pixelColor = Product(Product(albedo, Ll), occolor);
				Lr += brdf * pixelColor;
			}
		}

#ifdef AREALIGHT
		Lr /= NUM_AREA_SAMPLES;
#endif

		// Secondary rays: direct lighting is reduced by the fractions of light that are
		// mirrored or transmitted instead. An UNDEF fraction counts as 0.
		const bool hasMirror       = material->mirror       != UNDEF;
		const bool hasTransparency = material->transparency != UNDEF;

		if (hasMirror || hasTransparency)
		{
			const float mirror       = hasMirror       ? material->mirror       : 0.0f;
			const float transparency = hasTransparency ? material->transparency : 0.0f;

			Lr *= (1.0f - mirror - transparency);

			// Rays with zero weight cannot contribute, so they are not traced at all.
			if (transparency != 0.0f)
				Lr += FilteredContribution(transparency * TransparentColor(ray, N, into), material->refrac_filter);

			if (mirror != 0.0f)
				Lr += FilteredContribution(mirror * MirroredColor(ray, N), material->mirror_filter);
		}

		Lr += ambient;
		return Lr;
	}

private:
	/// @brief Directional factor of one surface layer in Schlick's approximation.
	/// @param r          the layer's r parameter (material->r1 / r2)
	/// @param p          the layer's isotropy factor (material->p1 / p2)
	/// @param pp         p * p, precomputed in the constructor
	/// @param t          |H . N|
	/// @param v          |L . N| (light side)
	/// @param x          |V . N| (viewer side)
	/// @param w2         squared cosine of the azimuth angle (ignored if !hasAzimuth)
	/// @param hasAzimuth false if the azimuth is undefined; the isotropic factor 1 is used
	/// @return the layer's directional factor
	static float DirectionalFactor(float r, float p, float pp,
	                               float t, float v, float x,
	                               float w2, bool hasAzimuth)
	{
		float zd = 1 + (r - 1) * t * t;
		float z  = r / (zd * zd);                                       // zenith factor
		float a  = hasAzimuth ? sqrt(p / (pp - pp * w2 + w2)) : 1.0f;   // azimuth factor
		float gv = v / (r - r * v + v);                                 // smith factor v
		float gx = x / (r - r * x + x);                                 // smith factor v'
		float g  = gv * gx;
		return (((1 - g) * v * x) + g * z / 4) * a / M_PI;
	}

	/// @brief Blends a weighted secondary-ray contribution towards the material color.
	/// @param contribution radiance of the mirrored/refracted ray, already scaled by its fraction
	/// @param filter       material->mirror_filter or refrac_filter; UNDEF means "no filtering"
	/// @return filter * mean(contribution) * color + (1 - filter) * contribution,
	///         or @p contribution unchanged if @p filter is UNDEF
	Vec3f FilteredContribution(const Vec3f &contribution, float filter) const
	{
		if (filter == UNDEF)
			return contribution;

		return (filter * (contribution.x + contribution.y + contribution.z) / 3.0f) * material->color
		     + (1.0f - filter) * contribution;
	}
};

#endif
