#include "terrain.h"

#include <cstddef>

cTerrain::cTerrain(int N, float power, float scale, bool geometry) : N(N), Nplus1(N+1), power(power), scale(scale), geometry(geometry), indices_count(0) {
	int index, index2;

	// prep our vertices
	vertices = new vertex_terrain[Nplus1 * Nplus1];
	for (int j = 0; j < Nplus1; j++) {
		for (int i = 0; i < Nplus1; i++) {
			index = i + j * Nplus1;
			vertices[index].x  = i;
			vertices[index].y  = 0;
			vertices[index].z  = j;
			vertices[index].nx = 0;
			vertices[index].ny = 1;
			vertices[index].nz = 0;
			vertices[index].tx = (float)i/(float)(N);
			vertices[index].ty = 0;
			vertices[index].tz = (float)j/(float)(N);
			vertices[index].cx = 0.0;
			vertices[index].cy = 0.0;
			vertices[index].cz = 0.0;
		}
	}

	// prep our indices
	indices = new unsigned int[N * N * 6];
	for (int j = 0; j < N; j++) {
		for (int i = 0; i < N; i++) {
			index = i + j * Nplus1;
			if (geometry) {
				indices[indices_count++] = index;
				indices[indices_count++] = index + 1;
				indices[indices_count++] = index;
				indices[indices_count++] = index + Nplus1;
				indices[indices_count++] = index + 1;
				indices[indices_count++] = index + Nplus1;
			} else {
				indices[indices_count++] = index;
				indices[indices_count++] = index + Nplus1;
				indices[indices_count++] = index + 1;
				indices[indices_count++] = index + 1;
				indices[indices_count++] = index + Nplus1;
				indices[indices_count++] = index + Nplus1 + 1;
			}
		}
	}


	complex *c = new complex[N*N];
	cFFT fft(N);

	// generate noise in the spatial domain
	for (int i = 0; i < N * N; i++) {
		c[i] = gaussianRandomVariable();//complex(rand()%2048/2048.0, rand()%2048/2048.0);
	}

	// transform to the frequency domain
	for (int j = 0; j < N; j++) fft.fft_(c, c, 1, j * N);
	for (int i = 0; i < N; i++) fft.fft_(c, c, N, i);

	// apply 1/(f^p) filter
	double f;
	for (int j = 0; j < N; j++) {
		for (int i = 0; i < N; i++) {
			index = i + j * N;
			f = sqrt((i-N/2)/float(N) * (i-N/2)/float(N) + (j-N/2)/float(N) * (j-N/2)/float(N));
			f = f < 1.0/N ? 1.0/N : f;
			//f = sqrt((i-N/2) * (i-N/2) + (j-N/2) * (j-N/2));
			//f = f < 1 ? 1.0 : f;
			c[index] = c[index]*(1.0f/ pow(f, power));
		}
	}

	// transform back to spatial domain
	for (int j = 0; j < N; j++) {
		fft.fft(c, c, 1, j * N);
		for (int i = 0; i < N; i++) {
			c[i + j * N] = c[i + j * N] * (1.0f/N);
		}
	}
	for (int i = 0; i < N; i++) {
		fft.fft(c, c, N, i);
		for (int j = 0; j < N; j++) {
			c[i + j * N] = c[i + j * N] * (1.0f/N);
		}
	}

	// below we build the height map and evaluate the minimum and maximum height for applying the color gradient
	double min=0.0;
	double max=0.0;
	int signs[] = { 1, -1 };
	height_map = new float[N*N];
	for (int j = 0; j < N; j++) {
		for (int i = 0; i < N; i++) {
			index = i + j * N;
			height_map[index] = scale * c[index].a * signs[(i+j)%2];
			min = height_map[index] < min ? height_map[index] : min;
			max = height_map[index] > max ? height_map[index] : max;
		}
	}

	delete [] c;

	// below we finish preparing our vertices by applying the height map and color value in addition to evaluating the normal vectors
	for (int j = 0; j < Nplus1; j++) {
		for (int i = 0; i < Nplus1; i++) {
			index  = i + j * Nplus1;

			index2 = (i == N && j == N) ? 0 : ((i == N) ? 0 + j * N : ((j == N) ? i : i + j * N));
			vertices[index].y = height_map[index2];
			vector3 c = color(height_map[index2], min, max);

			vertices[index].cx = c.x;
			vertices[index].cy = c.y;
			vertices[index].cz = c.z;
		}
	}
	// periodicity of the fourier transform allows us to tile our terrain.. below is crude implementation to wrap our normals.. assumes vertices have integer coordinates
	for (int j = 0; j < Nplus1 - 1; j++) {
		for (int i = 0; i < Nplus1 - 1; i++) {
			index = i + j * Nplus1;
			if (i == 0 && j == 0) {
				vector3 v0 = vector3(vertices[index].x,              vertices[index].y,              vertices[index].z);
				vector3 v1 = vector3(vertices[index+Nplus1].x,       vertices[index+Nplus1].y,       vertices[index+Nplus1].z);
				vector3 v2 = vector3(vertices[index+1].x,            vertices[index+1].y,            vertices[index+1].z);
				vector3 v3 = vector3(vertices[index+(N-1)*Nplus1].x, vertices[index+(N-1)*Nplus1].y, vertices[index+(N-1)*Nplus1].z-N);
				vector3 v4 = vector3(vertices[index+N-1].x-N,        vertices[index+N-1].y,          vertices[index+N-1].z);
				vector3 n0 = ((v1-v0).cross(v2-v0)).unit();
				vector3 n1 = ((v2-v0).cross(v3-v0)).unit();
				vector3 n2 = ((v3-v0).cross(v4-v0)).unit();
				vector3 n3 = ((v4-v0).cross(v1-v0)).unit();
				vector3 n = (n0+n1+n2+n3).unit();
				vertices[index].nx = vertices[index+N].nx = vertices[index+N*Nplus1].nx = vertices[index + N + N * Nplus1].nx = n.x;
				vertices[index].ny = vertices[index+N].ny = vertices[index+N*Nplus1].ny = vertices[index + N + N * Nplus1].ny = n.y;
				vertices[index].nz = vertices[index+N].nz = vertices[index+N*Nplus1].nz = vertices[index + N + N * Nplus1].nz = n.z;
			} else if (j == 0) {
				vector3 v0 = vector3(vertices[index].x,              vertices[index].y,              vertices[index].z);
				vector3 v1 = vector3(vertices[index+Nplus1].x,       vertices[index+Nplus1].y,       vertices[index+Nplus1].z);
				vector3 v2 = vector3(vertices[index+1].x,            vertices[index+1].y,            vertices[index+1].z);
				vector3 v3 = vector3(vertices[index+(N-1)*Nplus1].x, vertices[index+(N-1)*Nplus1].y, vertices[index+(N-1)*Nplus1].z-N);
				vector3 v4 = vector3(vertices[index-1].x,            vertices[index-1].y,            vertices[index-1].z);
				vector3 n0 = ((v1-v0).cross(v2-v0)).unit();
				vector3 n1 = ((v2-v0).cross(v3-v0)).unit();
				vector3 n2 = ((v3-v0).cross(v4-v0)).unit();
				vector3 n3 = ((v4-v0).cross(v1-v0)).unit();
				vector3 n = (n0+n1+n2+n3).unit();
				vertices[index].nx = vertices[index+N*Nplus1].nx = n.x;
				vertices[index].ny = vertices[index+N*Nplus1].ny = n.y;
				vertices[index].nz = vertices[index+N*Nplus1].nz = n.z;
			} else if (i == 0) {
				vector3 v0 = vector3(vertices[index].x,         vertices[index].y,        vertices[index].z);
				vector3 v1 = vector3(vertices[index+Nplus1].x,  vertices[index+Nplus1].y, vertices[index+Nplus1].z);
				vector3 v2 = vector3(vertices[index+1].x,       vertices[index+1].y,      vertices[index+1].z);
				vector3 v3 = vector3(vertices[index-Nplus1].x,  vertices[index-Nplus1].y, vertices[index-Nplus1].z);
				vector3 v4 = vector3(vertices[index+(N-1)].x-N, vertices[index+(N-1)].y,  vertices[index+(N-1)].z);
				vector3 n0 = ((v1-v0).cross(v2-v0)).unit();
				vector3 n1 = ((v2-v0).cross(v3-v0)).unit();
				vector3 n2 = ((v3-v0).cross(v4-v0)).unit();
				vector3 n3 = ((v4-v0).cross(v1-v0)).unit();
				vector3 n = (n0+n1+n2+n3).unit();
				vertices[index].nx = vertices[index+N].nx = n.x;
				vertices[index].ny = vertices[index+N].ny = n.y;
				vertices[index].nz = vertices[index+N].nz = n.z;
			} else {
				vector3 v0 = vector3(vertices[index].x,        vertices[index].y,        vertices[index].z);
				vector3 v1 = vector3(vertices[index+Nplus1].x, vertices[index+Nplus1].y, vertices[index+Nplus1].z);
				vector3 v2 = vector3(vertices[index+1].x,      vertices[index+1].y,      vertices[index+1].z);
				vector3 v3 = vector3(vertices[index-Nplus1].x, vertices[index-Nplus1].y, vertices[index-Nplus1].z);
				vector3 v4 = vector3(vertices[index-1].x,      vertices[index-1].y,      vertices[index-1].z);
				vector3 n0 = ((v1-v0).cross(v2-v0)).unit();
				vector3 n1 = ((v2-v0).cross(v3-v0)).unit();
				vector3 n2 = ((v3-v0).cross(v4-v0)).unit();
				vector3 n3 = ((v4-v0).cross(v1-v0)).unit();
				vector3 n = (n0+n1+n2+n3).unit();
				vertices[index].nx = n.x;
				vertices[index].ny = n.y;
				vertices[index].nz = n.z;
			}
		}
	}

	delete [] height_map;

	createProgram(glProgram, glShaderV, glShaderF, "src/terrainv.sh", "src/terrainf.sh");
	vertex         = glGetAttribLocation(glProgram, "vertex");
	normal         = glGetAttribLocation(glProgram, "normal");
	texture        = glGetAttribLocation(glProgram, "texturec");
	colorc         = glGetAttribLocation(glProgram, "colorc");
	light_position = glGetUniformLocation(glProgram, "light_position");
	projection     = glGetUniformLocation(glProgram, "Projection");
	view           = glGetUniformLocation(glProgram, "View");
	model          = glGetUniformLocation(glProgram, "Model");
	heightmap      = glGetUniformLocation(glProgram, "heightmap");

	glGenBuffers(1, &vbo_vertices);
	glBindBuffer(GL_ARRAY_BUFFER, vbo_vertices);
	glBufferData(GL_ARRAY_BUFFER, sizeof(vertex_terrain)*(Nplus1)*(Nplus1), vertices, GL_DYNAMIC_DRAW);

	glGenBuffers(1, &vbo_indices);
	glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, vbo_indices);
	glBufferData(GL_ELEMENT_ARRAY_BUFFER, indices_count*sizeof(unsigned int), indices, GL_STATIC_DRAW);
}

// Frees the CPU-side mesh arrays allocated by the constructor.
// GPU buffers and the shader program are freed separately by release().
cTerrain::~cTerrain() {
	delete [] indices;
	delete [] vertices;
}

// Maps height h linearly onto a brown-to-sand gradient.
// The result is c0 at h == min and c0 + c1 at h == max.
// If the range is empty (max <= min, e.g. a perfectly flat terrain), every height
// maps to c0 instead of dividing by zero and producing NaN colours.
vector3 cTerrain::color(double h, double min, double max) {
	vector3 c0( 66.0/255,  40.0/255,  18.0/255);
	vector3 c1(189.0/255, 185.0/255, 138.0/255);
	double range = max - min;
	double t = range > 0.0 ? (h - min) / range : 0.0;
	vector3 c = c0 + c1 * t;
	return c;
}

// Frees the GPU-side resources created by the constructor: the vertex and index
// buffer objects, and the shader program with its vertex/fragment shaders.
void cTerrain::release() {
	glDeleteBuffers(1, &vbo_vertices);
	glDeleteBuffers(1, &vbo_indices);
	releaseProgram(glProgram, glShaderV, glShaderF);
}

// Draws the terrain as a 2x2 block of tiles. The tiles sit at local offsets
// (0,0,0), (0,0,-N), (N,0,-N) and (N,0,0) relative to Model. This relies on the
// periodic height field built by the constructor.
// Model is only read. Earlier versions translated the caller's matrix in place,
// leaving it shifted by (N,0,0) after every call.
void cTerrain::render(glm::vec3& light_pos, glm::mat4& Projection, glm::mat4& View, glm::mat4& Model) {
	glUseProgram(glProgram);
	glUniform3f(light_position, light_pos.x, light_pos.y, light_pos.z);
	glUniformMatrix4fv(projection, 1, GL_FALSE, glm::value_ptr(Projection));
	glUniformMatrix4fv(view,       1, GL_FALSE, glm::value_ptr(View));

	// Interleaved vertex layout. Offsets are taken from the struct rather than
	// hard-coded, so they stay correct if vertex_terrain changes.
	glBindBuffer(GL_ARRAY_BUFFER, vbo_vertices);
	glEnableVertexAttribArray(vertex);
	glVertexAttribPointer(vertex,  3, GL_FLOAT, GL_FALSE, sizeof(vertex_terrain), (char *)NULL + offsetof(vertex_terrain, x));
	glEnableVertexAttribArray(normal);
	glVertexAttribPointer(normal,  3, GL_FLOAT, GL_FALSE, sizeof(vertex_terrain), (char *)NULL + offsetof(vertex_terrain, nx));
	glEnableVertexAttribArray(texture);
	glVertexAttribPointer(texture, 3, GL_FLOAT, GL_FALSE, sizeof(vertex_terrain), (char *)NULL + offsetof(vertex_terrain, tx));
	glEnableVertexAttribArray(colorc);
	glVertexAttribPointer(colorc,  3, GL_FLOAT, GL_FALSE, sizeof(vertex_terrain), (char *)NULL + offsetof(vertex_terrain, cx));

	glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, vbo_indices);

	const float size = static_cast<float>(N);
	const glm::vec3 tileOffsets[4] = {
		glm::vec3(0.0f, 0.0f,  0.0f),
		glm::vec3(0.0f, 0.0f, -size),
		glm::vec3(size, 0.0f, -size),
		glm::vec3(size, 0.0f,  0.0f)
	};
	for (int t = 0; t < 4; t++) {
		glm::mat4 tileModel = glm::translate(Model, tileOffsets[t]);
		glUniformMatrix4fv(model, 1, GL_FALSE, glm::value_ptr(tileModel));
		glDrawElements(geometry ? GL_LINES : GL_TRIANGLES, indices_count, GL_UNSIGNED_INT, 0);
	}
}