#version 330 core

#include shader/lib/logdepthbuff.glsl

// Color buffer: scene colour rendered before this pass. Sampled both at the
// fragment's own coordinate and at lensed (bent) coordinates.
uniform sampler2D u_texture0;
// Depth buffer (log). main() converts it with recoverWValue().
uniform sampler2D u_texture1;
// Camera matrix. NOTE(review): project() applies a perspective divide with this
// matrix, so it is presumably the combined projection * view matrix — confirm.
uniform mat4 u_modelView;
// Time in seconds (not referenced anywhere in this shader).
uniform float u_time;
// Floating position: camera position relative to the black hole
// (camPos - pos); the tracer treats the black hole as sitting at the origin.
uniform vec3 u_pos;
// Zfar: x and y are passed to recoverWValue() as its 2nd and 3rd arguments —
// presumably (zfar, log-depth constant k); TODO confirm against logdepthbuff.glsl.
uniform vec2 u_zfark;
// Viewport (not referenced anywhere in this shader).
uniform vec2 u_viewport;
// Use additional.x for bend scaling: it scales the bending applied to the
// "mild" ray whose direction is used to look up the lensed background.
uniform vec4 u_additional;
// Size of black hole: scales the step limits, the bending force and the
// capture / escape / disk-plane thresholds in blackhole().
uniform float u_size;

// Texture coordinate of this fragment in the colour/depth buffers.
in vec2 v_texCoords;
// Per-fragment view ray (normalized in main()).
in vec3 v_ray;
layout (location = 0) out vec4 fragColor;

// Transforms the point wc (w = 1) by `combined` and applies the perspective
// divide, returning xyz / w (normalized device coordinates when `combined` is a
// projection matrix). NOTE(review): no guard for w == 0.
vec3 prj(vec3 wc, mat4 combined){
    vec4 clip = combined * vec4(wc, 1.0);
    // clip.w is the same row-3 dot product the divide needs.
    float invW = 1.0 / clip.w;
    return (clip * invW).xyz;
}

// Projects wc through `combined` (see prj) and remaps each component from
// [-1, 1] to [0, 1], e.g. NDC to texture coordinates.
vec3 project(vec3 wc, mat4 combined){
    return prj(wc, combined) * 0.5 + 0.5;
}

// Cheap 1D pseudo-random hash: scales sin(x) by a large constant and keeps the
// fractional part, giving a value in [0, 1).
float hash(float x) {
    float scrambled = sin(x) * 152754.742;
    return fract(scrambled);
}

// 2D hash built by nesting the 1D hash: y is hashed first and the result
// offsets x before the final hash. Returns a value in [0, 1).
float hash(vec2 x) {
    float inner = hash(x.y);
    return hash(x.x + inner);
}

// Value noise: hashes the four integer lattice corners surrounding p * f and
// blends them with the cubic Hermite weight 3t^2 - 2t^3.
//   p - sample position
//   f - frequency (scale applied to p before the lattice lookup)
// Returns a value in [0, 1].
float value(vec2 p, float f)
{
    vec2 q = p * f;

    // Corner hashes, named c<dx><dy> for the offset from the cell origin.
    float c00 = hash(floor(q + vec2(0.0, 0.0)));
    float c10 = hash(floor(q + vec2(1.0, 0.0)));
    float c01 = hash(floor(q + vec2(0.0, 1.0)));
    float c11 = hash(floor(q + vec2(1.0, 1.0)));

    // Smoothed interpolation weights inside the cell.
    vec2 t = fract(q);
    vec2 w = (3.0 - 2.0 * t) * t * t;

    float bottom = mix(c00, c10, w.x);
    float top = mix(c01, c11, w.x);
    return mix(bottom, top, w.y);
}

// Accretion-disk sampler stub: the disk is disabled, so this always returns
// white with zero alpha. The parameters are unused and kept only so callers
// keep the same signature.
vec4 raymarchDisk(vec3 ray, vec3 zeroPos, float s) {
    const vec4 kNoDisk = vec4(vec3(1.0), 0.0);
    return kNoDisk;
}


// Traces one view ray past the black hole (located at the origin) and writes
// the resulting colour to colOut (alpha is always 1).
//
// The ray starts at `pos` heading along `ray` and is advanced in up to 20
// batches of 6 adaptive steps, bending its direction towards the origin at
// every step. After each batch the ray is classified:
//   |pos|   <  u_size * 0.1    -> captured: black plus accumulated disk colour
//   |pos|   >  u_size * 1000.0 -> escaped: the scene colour is sampled at the
//                                 projected direction of the mildly-bent ray
//   |pos.y| <= u_size * 0.001  -> crossed the disk plane (y == 0): accumulate
//                                 raymarchDisk() and push the ray past it
// If at any step the distance to the origin is >= s, the unbent scene colour
// at fragCoord is returned instead.
//
//   colOut    - output colour
//   fragCoord - texture coordinate of this fragment in u_texture0
//   ray       - normalized initial ray direction
//   pos       - initial ray origin relative to the black hole
//   s         - distance cut-off (main() passes the value decoded from the
//               depth buffer)
void blackhole(out vec4 colOut, in vec2 fragCoord, in vec3 ray, in vec3 pos, in float s) {
    // Floor for |ray.y| when used as a divisor, so rays parallel to the disk
    // plane do not produce Inf / NaN (0/0) step lengths.
    const float kMinRayY = 1e-6;

    vec4 col = vec4(0.0);      // accumulated disk colour
    vec4 outCol = vec4(0.0);
    bool resolved = false;     // true once the ray was captured or escaped
    vec3 ray_mild = ray;       // ray bent by the u_additional.x-scaled force; used for the background lookup
    float invSize = 1.0 / u_size;

    for (int disks = 0; disks < 20; disks++) // steps
    {
        // Batches of 6 steps reduce how often the exit conditions below are tested (to minimise branching).
        for (int h = 0; h < 6; h++)
        {
            float dotpos = dot(pos, pos);
            float invDist = inversesqrt(dotpos); // 1 / distance to BH
            float centDist = dotpos * invDist;   // distance to BH

            // Depth buffer: past the cut-off, show the unbent scene.
            // textureLod: we are in non-uniform control flow, where implicit
            // derivatives (and thus texture()) are undefined.
            if (centDist >= s) {
                colOut = clamp(vec4(textureLod(u_texture0, fragCoord, 0.0).xyz, 1.0), 0.0, 1.0);
                return;
            }

            float stepDist = 0.92 * (abs(pos.y) / max(abs(ray.y), kMinRayY)); // conservative distance to disk (y == 0)
            float farLimit = centDist * 0.5;                                    // limit step size far from BH
            float closeLimit = centDist * 0.1 + 0.05 * centDist * centDist * invSize; // limit step size close to BH
            stepDist = min(stepDist, min(farLimit, closeLimit));

            float invDistSqr = invDist * invDist;
            float bendForce = stepDist * invDistSqr * u_size * 0.08; // bending force
            ray = normalize(ray - (bendForce * invDist) * pos);      // bend ray towards BH
            ray_mild = normalize(ray_mild - (bendForce * u_additional.x * invDist) * pos);
            pos += stepDist * ray;
        }

        float dist = length(pos);

        if (dist < u_size * 0.1) // ray sucked into the BH
        {
            outCol = vec4(col.rgb * col.a, 1.0);
            resolved = true;
            break;
        }
        else if (dist > u_size * 1000.0) // ray escaped the BH
        {
            vec2 bend_tc = project(ray_mild, u_modelView).xy;
            // Scene colour before this shader ran; explicit LOD for the same
            // non-uniform-control-flow reason as above.
            vec4 bg = textureLod(u_texture0, bend_tc, 0.0);
            outCol = vec4(col.rgb * col.a + bg.rgb * (1.0 - col.a), 1.0);
            resolved = true;
            break;
        }
        else if (abs(pos.y) <= u_size * 0.001) // ray hit the accretion disk plane
        {
            vec4 diskCol = raymarchDisk(ray, pos, s);
            // Snap onto the plane, then push just past it so the next batch
            // does not immediately re-hit it.
            pos.y = 0.0;
            pos += (abs(u_size * 0.001) / max(abs(ray.y), kMinRayY)) * ray;
            col = vec4(diskCol.rgb * (1.0 - col.a) + col.rgb, col.a + diskCol.a * (1.0 - col.a));
        }
    }

    // The ray neither escaped nor got sucked in within the step budget.
    if (!resolved) {
        outCol = vec4(col.rgb, 1.0);
    }

    colOut = outCol;
}

// Fragment entry point: decodes this fragment's depth from the logarithmic
// depth buffer and traces the view ray past the black hole into fragColor.
void main(){
    // ray direction
    vec3 ray = normalize(v_ray);
    // floating position (camPos - pos)
    vec3 pos = u_pos;

    // Logarithmic depth buffer; recoverWValue is presumably provided by the
    // included logdepthbuff.glsl.
    float depth = 1.0 / recoverWValue(texture(u_texture1, v_texCoords).r, u_zfark.x, u_zfark.y);
    // NOTE(review): `ray` is already normalized, so this factor is ~1.0. If the
    // intent is to convert view-axis depth into distance along the ray, this
    // probably wants length(v_ray) — confirm against the vertex shader.
    depth *= length(ray);
    blackhole(fragColor, v_texCoords, ray, pos, depth);
}


