import React, { useEffect, useState } from "react";
import * as THREE from "three";
import { extendMaterial } from "../../lib/ExtendMaterial";
import { MeshBVH, acceleratedRaycast } from "three-mesh-bvh";
import { getIrradianceCoefficients } from "../../lib/Akima";
import { calculateIrradiance, createLUTTexture, createSpotlight } from "../../util";
import HeatmapTooltip, { TooltipData } from "../HeatmapTooltip";

// Module-load side effect: patch THREE.Mesh globally so that every Mesh in the
// app (not only this view's model) raycasts through three-mesh-bvh's
// acceleratedRaycast. It uses `geometry.boundsTree` when one has been built
// (the component below builds one per model mesh) and otherwise falls back to
// three.js's default mesh raycast.
THREE.Mesh.prototype.raycast = acceleratedRaycast;

/** Props for the irradiance heatmap view. */
interface IrradianceHeatmapViewProps {
  /** Scene that receives the view's spotlights, field-of-view cones and hover marker. */
  scene: THREE.Scene;
  /** Renderer whose shadow-map settings the view configures. */
  renderer: THREE.Renderer;
  /** Camera used to turn mouse positions into picking rays. */
  camera: THREE.Camera;
  /** Canvas the scene is drawn on; its bounding rect maps mouse coordinates to NDC. */
  canvas: React.RefObject<HTMLCanvasElement>;

  /** Object whose meshes are raycast against and shaded with the heatmap material. */
  model: THREE.Object3D;
  /** World-space light positions; one spotlight and one FoV cone are created per entry. */
  lightPositions: THREE.Vector3[];
}

/**
 * Adds a translucent cone to `scene` that visualises a light's field of view.
 *
 * The cone's apex sits at `position` and its base lies on the z = 0 plane.
 * The base radius is `height * tan(30°)`, which gives a 30° half-angle.
 *
 * @param position - Light position; its z component is used as the cone height.
 * @param scene - Scene the cone mesh is added to.
 * @param index - Used to name the mesh `FOV <index>`.
 */
function addFoVCone(position: THREE.Vector3, scene: THREE.Scene, index: number) {
  const height = position.z;
  // tan(30°) = opposite / adjacent  =>  radius = height * tan(30°) = height / √3
  const baseRadius = height * (1 / Math.sqrt(3));

  const cone = new THREE.Mesh(
    new THREE.ConeGeometry(baseRadius, height, 32),
    new THREE.MeshBasicMaterial({ color: 0x04fcdc, transparent: true, opacity: 0.5, alphaHash: true })
  );
  cone.name = `FOV ${index}`;

  // ConeGeometry is centred on its origin with the apex towards +y. Rotating +90°
  // about x turns the apex towards +z. Centring the mesh at z/2 then puts the apex
  // at the light and the base on z = 0.
  cone.rotation.x = Math.PI / 2;
  cone.position.set(position.x, position.y, position.z / 2);

  scene.add(cone);
}

/**
 * Tolerance for the shadow test. A ray from a light towards a surface point is
 * expected to hit the surface at that point itself, so a first hit that is closer
 * by less than this amount is not treated as an occluder.
 * NOTE(review): absolute world units; revisit if the model scale is unusual.
 */
const SHADOW_EPSILON = 1e-6;

/**
 * Determines which lights have an unobstructed line of sight to `point` on `model`.
 *
 * For each light, a ray is cast from the light towards the point. The light counts
 * as blocked only when the first hit on `model` is more than `SHADOW_EPSILON`
 * closer than the point.
 *
 * @param raycaster - Reused raycaster; its ray is overwritten.
 * @param model - Object tested for occlusion (recursively).
 * @param lightPositions - Light positions to test.
 * @param point - Surface point to test.
 * @returns Parallel arrays: the labels (`Zener <1-based index>`) of the unblocked
 *          lights and their straight-line distances to `point`.
 */
function findUnblockedLights(
  raycaster: THREE.Raycaster,
  model: THREE.Object3D,
  lightPositions: THREE.Vector3[],
  point: THREE.Vector3
): { hitByLights: string[]; distancesToLights: number[] } {
  const hitByLights: string[] = [];
  const distancesToLights: number[] = [];
  // Scratch vector; Raycaster.set copies the direction, so reusing it is safe.
  const towardsPoint = new THREE.Vector3();

  lightPositions.forEach((lightPosition, index) => {
    const distance = point.distanceTo(lightPosition);
    raycaster.set(lightPosition, towardsPoint.subVectors(point, lightPosition).normalize());
    const [firstHit] = raycaster.intersectObject(model, true);

    // Compare numerically. The previous `toFixed(8)` comparison compared strings
    // lexicographically (e.g. "10.0…" < "9.0…"), which misclassified occlusion.
    const isBlocked = firstHit !== undefined && firstHit.distance < distance - SHADOW_EPSILON;

    if (!isBlocked) {
      hitByLights.push(`Zener ${index + 1}`);
      distancesToLights.push(distance);
    }
  });

  return { hitByLights, distancesToLights };
}

/** Disposes a single material or every material of a multi-material mesh. */
function disposeMaterial(material: THREE.Material | THREE.Material[]): void {
  (Array.isArray(material) ? material : [material]).forEach((m) => m.dispose());
}

/**
 * Shades `model` with an irradiance heatmap and shows a tooltip with per-point
 * irradiance data for the surface point under the mouse cursor.
 *
 * Side effects on the passed-in three.js objects:
 * - enables basic shadow mapping on `renderer` (left enabled on cleanup);
 * - adds one spotlight and one field-of-view cone per light, plus a hover marker,
 *   to `scene` (all removed and disposed on cleanup);
 * - builds a BVH on every mesh geometry of `model` (kept on cleanup);
 * - sets castShadow/receiveShadow on every mesh of `model` (kept on cleanup);
 * - swaps every mesh material of `model` for the heatmap material (the original
 *   materials are restored on cleanup).
 */
export default function IrradianceHeatmapView({
  scene,
  model,
  lightPositions,
  renderer,
  camera,
  canvas
}: IrradianceHeatmapViewProps): React.ReactElement {
  // One raycaster for the component's lifetime, shared by both effects below.
  const [raycaster] = useState(() => new THREE.Raycaster());
  const [markerMesh, setMarkerMesh] = useState<THREE.Mesh | null>(null);
  const [tooltip, setTooltip] = useState<TooltipData | null>(null);

  // Scene setup: lights, FoV cones, hover marker, BVHs and the heatmap material.
  useEffect(() => {
    if (!scene || !lightPositions.length || !renderer || !canvas) {
      return;
    }

    const { minDistInvSquareCoeff, maxDistInvSquareCoeff, minX } = getIrradianceCoefficients();

    renderer.shadowMap.enabled = true;
    renderer.shadowMap.type = THREE.BasicShadowMap;

    const spotLights = lightPositions.map((position) => {
      const spotLight = createSpotlight(position);
      scene.add(spotLight);
      return spotLight;
    });

    // addFoVCone appends its cone to scene.children; remember the appended
    // objects so cleanup can remove them (they used to accumulate on every re-run).
    const firstConeIndex = scene.children.length;
    lightPositions.forEach((position, index) => addFoVCone(position, scene, index));
    const fovCones = scene.children.slice(firstConeIndex);

    // Small translucent black sphere that follows the hovered surface point.
    const markerGeometry = new THREE.SphereGeometry(0.05, 16, 16);
    const markerMaterial = new THREE.MeshBasicMaterial({
      color: 0x000000,
      opacity: 0.5,
      transparent: true,
    });
    const marker = new THREE.Mesh(markerGeometry, markerMaterial);
    scene.add(marker);
    setMarkerMesh(marker);

    model.traverse((child) => {
      if (child instanceof THREE.Mesh && child.geometry instanceof THREE.BufferGeometry) {
        child.geometry.boundsTree = new MeshBVH(child.geometry);
      }
    });

    // Directions uniformly distributed over the unit sphere (phi = polar angle,
    // theta = azimuth). Sampling is random, so the sampled min/max irradiance,
    // and therefore the colour scale, can vary slightly between runs.
    const numberOfRays = 100;
    const directions = Array.from({ length: numberOfRays }, () => {
      const phi = Math.acos(2 * Math.random() - 1);
      const theta = 2 * Math.PI * Math.random();
      return new THREE.Vector3(
        Math.sin(phi) * Math.cos(theta),
        Math.sin(phi) * Math.sin(theta),
        Math.cos(phi)
      ).normalize();
    });

    const collectedPoints: {
      position: THREE.Vector3;
      zeners: string[];
      distances: number[];
      irradiance?: number;
    }[] = [];
    // Exact-coordinate keys of collected points. This is the same test as
    // Vector3.equals, but O(1) per lookup instead of a linear scan.
    const seenPoints = new Set<string>();

    // Sample surface points hit by rays from every light, then record which
    // lights can see each point.
    lightPositions.forEach((position) => {
      directions.forEach((direction) => {
        raycaster.set(position, direction);
        const [intersect] = raycaster.intersectObject(model, true);
        if (!intersect) {
          return;
        }

        const { point } = intersect;
        const key = `${point.x},${point.y},${point.z}`;
        if (seenPoints.has(key)) {
          return;
        }
        seenPoints.add(key);

        const { hitByLights, distancesToLights } = findUnblockedLights(raycaster, model, lightPositions, point);
        collectedPoints.push({
          position: point.clone(),
          zeners: hitByLights,
          distances: distancesToLights,
        });
      });
    });

    // Min/max irradiance over the sampled points; these define the colour scale.
    const { minIrr, maxIrr } = calculateIrradiance(collectedPoints);

    if (minIrr === undefined || maxIrr === undefined) {
      console.error("No valid min or max irr found.");
    }

    const lutTexture = createLUTTexture(200);
    lutTexture.needsUpdate = true;

    // NOTE(review): the shader sums every light without any occlusion test. The
    // CPU-side samples (and the hover tooltip) presumably only count lights with
    // line of sight. Confirm this difference is intended.
    const combinedMaterial = extendMaterial(THREE.MeshPhongMaterial, {
      class: THREE.ShaderMaterial,
      explicit: true,
      uniforms: {
        lutTexture: { value: lutTexture },
        minIrr: { value: minIrr },
        maxIrr: { value: maxIrr },
        lightPositions: { value: lightPositions },
        a: { value: minDistInvSquareCoeff },
        b: { value: maxDistInvSquareCoeff },
        // Passed as a uniform instead of being interpolated into the GLSL source.
        // An integer-valued number would become an int literal, and GLSL ES has no
        // implicit int -> float conversion in `dist < minX`.
        minX: { value: minX },
      },

      header: `
        varying vec3 vWorldPosition;
        uniform sampler2D lutTexture;
        uniform float minIrr;
        uniform float maxIrr;
        uniform float a;
        uniform float b;
        uniform float minX;
        uniform vec3 lightPositions[${lightPositions.length}];
      `,

      vertex: {
        "#include <project_vertex>": `
          vWorldPosition = (modelMatrix * vec4(position, 1.0)).xyz;
          gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
        `,
      },
      fragment: {
        "#include <color_fragment>": `
          float totalIrradiance = 0.0;

          for (int i = 0; i < ${lightPositions.length}; i++) {
            float dist = length(vWorldPosition - lightPositions[i]);
            if (dist > 0.0) {
              // Inverse-square falloff with coefficient a below minX and b from minX on.
              totalIrradiance += (dist < minX ? a : b) / (dist * dist);
            }
          }

          // Normalize irradiance to [0, 1]. A zero (or NaN) range falls back to 0.0
          // instead of dividing by zero.
          float irrRange = maxIrr - minIrr;
          float normalizedIrradiance = irrRange > 0.0
            ? clamp((totalIrradiance - minIrr) / irrRange, 0.0, 1.0)
            : 0.0;

          // Invert for LUT mapping (0 -> red, 1 -> blue)
          float intensity = 1.0 - normalizedIrradiance;
          vec3 color = texture2D(lutTexture, vec2(intensity, 0.5)).rgb;
          diffuseColor.rgb = color;
        `,
      },
      material: {
        polygonOffset: true,
        polygonOffsetFactor: -0.1,
        side: THREE.DoubleSide,
      },
    });

    // Remember each mesh's own material so cleanup can put it back.
    const originalMaterials = new Map<THREE.Mesh, THREE.Material | THREE.Material[]>();
    model.traverse((child) => {
      if (child instanceof THREE.Mesh) {
        originalMaterials.set(child, child.material);
        child.castShadow = true;
        child.receiveShadow = true;
        child.material = combinedMaterial;
      }
    });

    return () => {
      originalMaterials.forEach((material, mesh) => {
        mesh.material = material;
      });
      combinedMaterial.dispose();
      lutTexture.dispose();

      spotLights.forEach((light) => scene.remove(light));

      fovCones.forEach((cone) => {
        scene.remove(cone);
        if (cone instanceof THREE.Mesh) {
          cone.geometry.dispose();
          disposeMaterial(cone.material);
        }
      });

      scene.remove(marker);
      markerGeometry.dispose();
      markerMaterial.dispose();

      // The marker is gone. Stop the hover handler from driving it and drop any
      // tooltip it showed.
      setMarkerMesh(null);
      setTooltip(null);
    };
  }, [model, lightPositions, scene, renderer, canvas, raycaster]);

  // Hover handling: move the marker to the surface point under the cursor and
  // show that point's irradiance data in the tooltip.
  useEffect(() => {
    if (!markerMesh || !camera) {
      return;
    }

    const onMouseMove = (event: MouseEvent) => {
      const canvasElement = canvas.current;
      if (!canvasElement) {
        return;
      }

      const rect = canvasElement.getBoundingClientRect();
      const localX = event.clientX - rect.left;
      const localY = event.clientY - rect.top;

      // The listener is on window, so ignore the cursor when it is outside the
      // canvas. This also avoids dividing by a zero-sized rect.
      const isOverCanvas = localX >= 0 && localY >= 0 && localX < rect.width && localY < rect.height;

      let intersect: THREE.Intersection | undefined;
      if (isOverCanvas) {
        const mouse = new THREE.Vector2(
          (localX / rect.width) * 2 - 1,
          -(localY / rect.height) * 2 + 1
        );
        raycaster.setFromCamera(mouse, camera);
        intersect = raycaster.intersectObject(model, true)[0];
      }

      if (!intersect) {
        markerMesh.visible = false;
        setTooltip(null);
        return;
      }

      const { point } = intersect;
      const { hitByLights, distancesToLights } = findUnblockedLights(raycaster, model, lightPositions, point);
      const hoveredPoint = {
        position: point.clone(),
        zeners: hitByLights,
        distances: distancesToLights,
      };
      const { updatedPoints } = calculateIrradiance([hoveredPoint]);

      markerMesh.position.copy(point);
      markerMesh.visible = true;

      setTooltip({
        x: event.clientX,
        y: event.clientY,
        pointData: updatedPoints[0],
      });
    };

    window.addEventListener("mousemove", onMouseMove);

    // The marker belongs to the scene-setup effect, which removes it. Removing it
    // here would drop it from the scene whenever only `camera` changed.
    return () => {
      window.removeEventListener("mousemove", onMouseMove);
    };
  }, [markerMesh, camera, canvas, model, lightPositions, raycaster]);

  return <HeatmapTooltip tooltip={tooltip} />;
}