import React, { useState, useCallback, useMemo, useEffect } from "react";
import Plot from "react-plotly.js";
import { dimensions, clusterColors, geneColorMapping } from "../constants";
import {
  FormControl,
  Select,
  MenuItem,
  Switch,
  FormControlLabel,
  InputLabel,
  Card,
  CardContent,
  Typography,
} from "@material-ui/core";
import { styled } from "@mui/system";
import { useAudioProfile } from "../contexts/AudioProfileContext";
import { makeStyles } from "@material-ui/core/styles";
import { Button } from "react-bootstrap";

// JSS class names for ThreeDClustering, generated with Material-UI v4's
// `makeStyles`; spacing values come from the active theme.
const useStyles = makeStyles((theme) => ({
  formControl: {
    margin: theme.spacing(1),
    minWidth: "20%",
    maxWidth: "80%",
  },
  formControlPatient: {
    margin: theme.spacing(1),
    width: "40%",
  },
  selectEmpty: {
    marginTop: theme.spacing(2),
  },
  // Outer wrapper: the two plot columns side by side.
  container: {
    padding: theme.spacing(2),
    display: "flex",
    flexDirection: "row", // Lay the plot columns out horizontally
    alignItems: "flex-start", // Align columns to the top of the container
    justifyContent: "space-around", // Distribute free space around the columns
    height: "100%",
    width: "100%",
    overflowY: "auto",
  },
  plotColumn: {
    display: "flex",
    flexDirection: "column", // Stack controls and plot vertically in each column
    justifyContent: "flex-start",
    alignItems: "center",
    height: "100%",
    flex: 1, // Let columns grow to fill the container
  },
  // Horizontal row of controls above each plot.
  buttonGroup: {
    display: "flex",
    // Previously "left"; for a row in a left-to-right layout that is the
    // same as the flexbox keyword "flex-start".
    justifyContent: "flex-start",
    // Previously "left", which is not a valid align-items value and was
    // dropped by the browser, leaving flex items stretched; say so explicitly.
    alignItems: "stretch",
    marginBottom: theme.spacing(2),
    width: "100%", // Take the full width of the enclosing column
    flexDirection: "row",
  },
  menuItem: {
    minWidth: "10%",
    maxWidth: "100%",
    color: "black",
  },
  // One entry in the "Clicked Points" card; the first entry is inverted.
  clickedPointItem: {
    height: "30%",
    padding: "10px", // Inner padding for readability
    backgroundColor: "white",
    color: "black",
    "&:first-child": {
      backgroundColor: "black",
      color: "white",
    },
    fontSize: "1rem",
  },
  button: {
    backgroundColor: "white",
    color: "black",
    width: "25%",
    height: "50px",
    borderRadius: "5px",
    boxShadow: "0px 4px 8px rgba(0, 0, 0, 0.2)", // Shadow for a "floating" look
    transition: "box-shadow 0.3s ease-in-out, transform 0.3s ease-in-out", // Smooth hover transition
    "&:hover": {
      backgroundColor: "black",
      color: "white",
      boxShadow: "0px 6px 12px rgba(0, 0, 0, 0.3)", // Deeper shadow on hover
      transform: "translateY(-2px)", // Lift the button slightly on hover
    },
    "&:disabled": {
      backgroundColor: "#f7f7f7",
      boxShadow: "none", // No floating effect while disabled
    },
    margin: theme.spacing(2),
  },
}));

// Menu item used by every Select in ThreeDClustering: black text, inverted
// colours on hover and when selected, italic labels.
const StyledMenuItem = styled(MenuItem)({
  "&.MuiMenuItem-root": {
    color: "black",
    "&:hover": {
      backgroundColor: "black",
      color: "white",
    },
  },
  // Material-UI marks the selected item with the global state class
  // `Mui-selected` (there is no `MuiMenuItem-selected` class, so the previous
  // selector never matched).
  "&.Mui-selected": {
    backgroundColor: "black",
    color: "white",
  },
  width: "100%",
  fontStyle: "italic",
});

/**
 * Wraps `func` so that a burst of calls collapses into a single call made
 * `wait` milliseconds after the most recent one, using that call's arguments.
 * `func` is invoked without a `this` binding.
 *
 * @param {Function} func - Function to delay.
 * @param {number} wait - Quiet period in milliseconds.
 * @returns {Function} The debounced wrapper.
 */
function debounce(func, wait) {
  let pendingTimer;
  const fire = (latestArgs) => {
    clearTimeout(pendingTimer);
    func(...latestArgs);
  };
  return function executedFunction(...args) {
    // Restart the quiet period on every call; only the last one survives.
    clearTimeout(pendingTimer);
    pendingTimer = setTimeout(fire, wait, args);
  };
}

const ThreeDClustering = ({
  clusteringDataGreedy,
  clusteringDataOriginal,
  geneCountsAfterResampling,
  table_clustering_data,
  hypothesis,
}) => {
  const classes = useStyles();
  const { predictionData } = useAudioProfile();
  const [selectedCluster, setSelectedCluster] = useState(null);
  const [selectedClusteringMethod, setSelectedClusteringMethod] =
    useState("original");
  const [selectedGene, setSelectedGene] = useState(null);
  const [showPercentage, setShowPercentage] = useState(false);
  const [cameraPosition, setCameraPosition] = useState(() => {
    // Attempt to load saved state from local storage
    const saved = localStorage.getItem("threeDClusteringCameraPosition");
    if (saved !== null) {
      try {
        // Attempt to parse the saved JSON string
        return JSON.parse(saved);
      } catch (e) {
        // Log the error to the console
        console.error("Failed to parse camera position from localStorage:", e);
      }
    }
    // Return a default value if parsing fails or if there's no saved state
    return {
      eye: { x: 1.25, y: 1.25, z: 1.25 },
    };
  });
  const [clickedPoints, setClickedPoints] = useState([]);
  const [visualizationMethod, setVisualizationMethod] =
    useState("cluster and gene");
  const [showHoverTemplate, setShowHoverTemplate] = useState(true);
  console.log("Show Hover Template: ", showHoverTemplate);

  let data =
    selectedClusteringMethod === "greedy"
      ? clusteringDataGreedy
      : clusteringDataOriginal;
  const clusters = useMemo(
    () => [...new Set(data.map((item) => item.cluster))],
    [data],
  );
  const genes = hypothesis.selectedGenes;

  const [selectedPatient, setSelectedPatient] = useState("No Patient");
  const [plotTitle, setPlotTitle] = useState("Cluster and Gene - Cluster 0");

  const handleClusteringMethodChange = (event) => {
    setSelectedClusteringMethod(event.target.value);
  };

  const handleToggleChange = (event) => {
    setShowPercentage(event.target.checked);
  };

  // Add an option for "All Patients"
  const allPatientsOption = "No Patient";

  const handlePatientChange = (event) => {
    const selectedValue = event.target.value;
    setSelectedPatient(selectedValue);
  };

  // Function to calculate the percentage of a gene in a cluster
  const calculatePercentage = (gene, clusterCount) => {
    const totalCount = geneCountsAfterResampling[gene] || 1; // Fallback to 1 to avoid division by zero
    return (clusterCount / totalCount) * 100;
  };

  const geneColor = useCallback(
    (gene) => geneColorMapping[gene] || "black",
    [],
  );
  const clusterColor = useCallback(
    (cluster) => clusterColors[cluster] || "black",
    [],
  );

  const barTraces = genes.map((gene) => {
    const geneCounts = clusters.map((cluster) => {
      let clusterPoints = data.filter((point) => point.cluster === cluster);
      // Removed filtering based on selectedGenes
      return clusterPoints.filter((point) => point.gene === gene).length;
    });
    // Conditionally format the name based on the selectedClusteringMethod
    let traceName;
    if (selectedClusteringMethod === "greedy") {
      // For greedy method, include gene, total count, and cluster numbers in the name
      const clustersWithGene = clusters
        .filter((cluster, index) => geneCounts[index] > 0)
        .join(", ");
      traceName = `${gene} (${geneCounts.reduce((a, b) => a + b, 0)}) - Cluster: ${clustersWithGene}`;
    } else {
      // For other methods, just include gene and total count
      traceName = `${gene} (${geneCounts.reduce((a, b) => a + b, 0)})`;
    }
    return {
      x: clusters,
      y: geneCounts,
      name: traceName,
      type: "bar",
      hovertemplate: "Cluster: %{x}<br>Count: %{y}<br>Gene: " + gene,
      marker: {
        color: geneColor(gene),
      },
    };
  });

  // Modify barTraces to show percentages if toggle is on
  const modifiedBarTraces = barTraces.map((trace) => {
    const geneName = trace.name.split(" ")[0]; // Extract the gene name
    if (showPercentage) {
      return {
        ...trace,
        y: trace.y.map((count, index) => calculatePercentage(geneName, count)), // Use the extracted gene name
        hovertemplate: `Cluster: %{x}<br>Percentage: %{y:.2f}%<br>Gene: ${geneName}`,
        name: `${geneName} (${trace.y.reduce((a, b) => a + b, 0)})`, // Add the total count to the name
      };
    }
    return trace;
  });

  // Function to generate traces for table_clustering_data
  const generateTableClusteringTraces = useCallback(() => {
    const tableData = table_clustering_data
      .filter(
        (dataPoint) => !selectedPatient || dataPoint.id === selectedPatient,
      )
      .map((dataPoint) => {
        let predictedGenesV4 = predictionData["audiogenev4"].find(
          (patient) => patient.id === dataPoint.id,
        );
        let predictedGenesV9 = predictionData["audiogenev9"].find(
          (patient) => patient.id === dataPoint.id,
        );
        let predictedGenes = predictedGenesV4 || predictedGenesV9;
        if (predictedGenes) {
          predictedGenes = predictedGenes.genes.slice(0, 3).join(", ");
        } else {
          predictedGenes = "";
        }
        return {
          x: [dataPoint.x],
          y: [dataPoint.y],
          z: [dataPoint.z],
          mode: "markers",
          marker: {
            size: 10,
            color: "red",
          },
          type: "scatter3d",
          name: `User Data: ID ${dataPoint.id}`,
          text: [
            `ID: ${dataPoint.id}<br>Top 3 Predicted Genes: ${predictedGenes}<br>Closest Gene Instances: ${dataPoint.closest_genes.slice(0, 3).join(", ")}<br>Cluster: ${dataPoint.cluster}`,
          ],
          hovertemplate: showHoverTemplate ? `%{text}` : "",
        };
      });
    return tableData;
  }, [table_clustering_data, selectedPatient, showHoverTemplate]);

  // // Modify the handleBarHover function
  // const handleBarHover = (data) => {
  //   if (visualizationMethod === "cluster and gene") {
  //     setSelectedGene(data.points[0].data.name.split(" ")[0]);
  //     setSelectedCluster(data.points[0].x);
  //     setPlotTitle(`Cluster and Gene - Cluster ${data.points[0].x}`); // Update plot title
  //   } else if (visualizationMethod === "byCluster") {
  //     setSelectedCluster(data.points[0].x);
  //     setPlotTitle(`By Cluster - Cluster ${data.points[0].x}`); // Update plot title
  //   }
  //   // For 'clustersVisualization', hovering over the bar chart does nothing
  // };

  // Modify the handleBarClick function
  const handleBarClick = (data) => {
    if (visualizationMethod === "cluster and gene") {
      setSelectedGene(data.points[0].data.name.split(" ")[0]);
      setSelectedCluster(data.points[0].x);
      setPlotTitle(`Cluster and Gene - Cluster ${data.points[0].x}`); // Update plot title
    } else if (visualizationMethod === "byCluster") {
      setSelectedCluster(data.points[0].x);
      setPlotTitle(`By Cluster - Cluster ${data.points[0].x}`); // Update plot title
    }
    // For 'clustersVisualization', clicking on the bar chart does nothing
  };

  const generateScatterTraces = useCallback(() => {
    let traces = [];

    switch (visualizationMethod) {
      case "clustersVisualization":
        // Logic for "Cluster Visualization"
        traces = clusters.map((cluster) => {
          let clusterPoints = data.filter((point) => point.cluster === cluster);
          return {
            x: clusterPoints.map((point) => point.x),
            y: clusterPoints.map((point) => point.y),
            z: clusterPoints.map((point) => point.z),
            mode: "markers",
            type: "scatter3d",
            name: `Cluster ${cluster}`,
            marker: {
              size: 3,
              color: clusterColor(cluster), // Color for each cluster
            },
            hoverinfo: !showHoverTemplate ? "text" : undefined,
            hovertemplate: !showHoverTemplate ? "%{text}" : undefined,
            text: clusterPoints.map(
              (point) => `Gene: ${point.gene}<br>Cluster: ${point.cluster}`,
            ),
          };
        });
        break;

      case "cluster and gene":
        // Logic for "Cluster and Gene"
        data.forEach((point) => {
          const color =
            point.cluster === selectedCluster
              ? geneColor(point.gene)
              : "lightgrey";
          const traceIndex = traces.findIndex(
            (trace) => trace.name === `Gene: ${point.gene}`,
          );
          if (traceIndex === -1) {
            traces.push({
              x: [point.x],
              y: [point.y],
              z: [point.z],
              mode: "markers",
              type: "scatter3d",
              name: `Gene: ${point.gene}`,
              marker: { size: 6, color: [color] },
              hoverinfo: !showHoverTemplate ? "text" : undefined,
              hovertemplate: !showHoverTemplate ? "%{text}" : undefined,
              text: [`Gene: ${point.gene}<br>Cluster: ${point.cluster}`],
            });
          } else {
            traces[traceIndex].x.push(point.x);
            traces[traceIndex].y.push(point.y);
            traces[traceIndex].z.push(point.z);
            traces[traceIndex].marker.color.push(color);
            traces[traceIndex].text.push(
              `Gene: ${point.gene}<br>Cluster: ${point.cluster}`,
            );
          }
        });
        break;

      case "byCluster":
        // Logic for "By Cluster"
        if (selectedCluster !== null) {
          // Initialize an object to hold two sets of points: those in the selected cluster and those not
          let inClusterPoints = {
            x: [],
            y: [],
            z: [],
            marker: { color: [], size: 3 },
          };
          let outOfClusterPoints = {
            x: [],
            y: [],
            z: [],
            marker: { color: [], size: 3 },
          };

          // Iterate over each data point to classify it as in-cluster or out-of-cluster
          data.forEach((point) => {
            if (point.cluster === selectedCluster) {
              inClusterPoints.x.push(point.x);
              inClusterPoints.y.push(point.y);
              inClusterPoints.z.push(point.z);
              inClusterPoints.marker.color.push("red"); // Highlight in-cluster points in red
            } else {
              outOfClusterPoints.x.push(point.x);
              outOfClusterPoints.y.push(point.y);
              outOfClusterPoints.z.push(point.z);
              outOfClusterPoints.marker.color.push("lightgrey"); // Grey out out-of-cluster points
            }
          });

          // Add the in-cluster points trace
          traces.push({
            ...inClusterPoints,
            mode: "markers",
            type: "scatter3d",
            name: `Cluster ${selectedCluster} (In)`,
            hoverinfo: !showHoverTemplate ? "text" : undefined,
            hovertemplate: !showHoverTemplate ? "%{text}" : undefined,
            text: data
              .filter((point) => point.cluster === selectedCluster)
              .map(
                (point) => `Gene: ${point.gene}<br>Cluster: ${point.cluster}`,
              ),
          });

          // Add the out-of-cluster points trace
          traces.push({
            ...outOfClusterPoints,
            mode: "markers",
            type: "scatter3d",
            name: `Cluster ${selectedCluster} (Out)`,
            hoverinfo: !showHoverTemplate ? "text" : undefined,
            hovertemplate: !showHoverTemplate ? "%{text}" : undefined,
            text: data
              .filter((point) => point.cluster !== selectedCluster)
              .map(
                (point) => `Gene: ${point.gene}<br>Cluster: ${point.cluster}`,
              ),
          });
        } else {
          // If no cluster is selected, default to light grey for all points
          traces.push({
            x: data.map((point) => point.x),
            y: data.map((point) => point.y),
            z: data.map((point) => point.z),
            mode: "markers",
            type: "scatter3d",
            name: "All Clusters",
            marker: {
              size: 3,
              color: "lightgrey",
            },
            hoverinfo: !showHoverTemplate ? "text" : undefined,
            hovertemplate: !showHoverTemplate ? "%{text}" : undefined,
            text: data.map(
              (point) => `Gene: ${point.gene}<br>Cluster: ${point.cluster}`,
            ),
          });
        }
        break;
      default:
        // Handle default case or other visualization methods if any
        break;
    }

    return traces;
  }, [
    data,
    visualizationMethod,
    selectedCluster,
    clusters,
    geneColor,
    clusterColor,
    showHoverTemplate,
  ]);

  const handleRelayout = useCallback((event) => {
    // Check if the relayout event includes camera changes
    if (event["scene.camera"]) {
      const newCameraPosition = {
        eye: event["scene.camera"].eye, // This controls the zoom
        up: event["scene.camera"].up,
        center: event["scene.camera"].center,
        projection: event["scene.camera"].projection,
      };

      // Use the debounced function to handle saving the camera position
      debouncedSaveCameraPosition(newCameraPosition);
    }
  }, []); // Dependencies array is empty if no external dependencies are needed

  // Create a debounced version of the function that saves the camera position
  const debouncedSaveCameraPosition = debounce((newCameraPosition) => {
    // Save the new camera position to localStorage or elsewhere
    setCameraPosition(newCameraPosition);
    localStorage.setItem(
      "threeDClusteringCameraPosition",
      JSON.stringify(newCameraPosition),
    );
  }, 150); // Debounce for 100 ms

  const handleClickedPoint = (event) => {
    // Assuming `event.points` is an array of points, where each point has information about the click event
    if (event.points && event.points.length > 0) {
      const clickedPoint = event.points[0]; // Get the first (and likely only) clicked point
      console.log("Clicked point data:", clickedPoint);
      // Use the point's index to access its specific text information
      const pointIndex = clickedPoint.pointNumber; // This might vary based on the event object structure
      let pointText = clickedPoint.data.text[pointIndex]; // Access the specific text for the clicked point

      // Remove HTML tags from the pointText
      pointText = pointText.replace(/<br>/g, ", "); // Replace <br> with a space
      pointText = pointText.replace(/<[^>]*>/g, ""); // Remove any other HTML tags
      // Now you can use `pointText` which contains only the information for the clicked point
      // For example, updating state or displaying the information somewhere in the UI
      setClickedPoints([pointText, ...clickedPoints]); // Assuming you're storing clicked points in a state
    }
  };

  const resetView = () => {
    debouncedSaveCameraPosition({
      eye: { x: 1.25, y: 1.25, z: 1.25 },
    });
  };

  // Ensure scatter traces are initialized on component mount and when dependencies change
  const modifiedScatterTraces = useMemo(() => {
    const scatterTraces = generateScatterTraces();
    const tableTraces = generateTableClusteringTraces();
    return [...scatterTraces, ...tableTraces];
  }, [generateScatterTraces, generateTableClusteringTraces]);

  // useMemo hook to memoize the plot layout configuration
  const plotLayout = useMemo(
    () => ({
      title: plotTitle,
      scene: {
        xaxis: { title: "X" },
        yaxis: { title: "Y" },
        zaxis: { title: "Z" },
        camera: cameraPosition, // Use the memoized cameraPosition state
      },
      margin: {
        l: 0,
        r: 260,
        b: 0,
        t: 100,
      },
      height: dimensions.height / 1.5,
      width: dimensions.width / 2,
      hovermode: "closest",
      hoverlabel: {
        bgcolor: "white",
        font: { color: "black" },
      },
    }),
    [plotTitle, cameraPosition],
  ); // Dependencies

  // Ensure scatter traces are initialized on component mount
  useEffect(() => {
    generateScatterTraces();
  }, [generateScatterTraces]);

  return (
    <div className={classes.container}>
      <div className={classes.plotColumn}>
        <div className={classes.buttonGroup}>
          <FormControl variant="outlined" className={classes.formControl}>
            <InputLabel id="clustering-method-label">
              Clustering Method
            </InputLabel>
            <Select
              labelId="clustering-method-label"
              id="clustering-method-select"
              value={selectedClusteringMethod}
              onChange={handleClusteringMethodChange}
              label="Clustering Method"
            >
              <StyledMenuItem value="original" className={classes.menuItem}>
                Original Clustering
              </StyledMenuItem>
              <StyledMenuItem value="greedy" className={classes.menuItem}>
                Greedy Clustering
              </StyledMenuItem>
            </Select>
          </FormControl>
          <FormControl variant="outlined" className={classes.formControl}>
            <InputLabel id="visualization-method-label">
              Visualization Method
            </InputLabel>
            <Select
              labelId="visualization-method-label"
              id="visualization-method-select"
              value={visualizationMethod}
              onChange={(event) => setVisualizationMethod(event.target.value)}
            >
              <StyledMenuItem
                value="cluster and gene"
                className={classes.menuItem}
              >
                Cluster and Gene
              </StyledMenuItem>
              <StyledMenuItem value="byCluster" className={classes.menuItem}>
                By Cluster
              </StyledMenuItem>
              <StyledMenuItem
                value="clustersVisualization"
                className={classes.menuItem}
              >
                Clusters Visualization
              </StyledMenuItem>
            </Select>
          </FormControl>
          <FormControlLabel
            control={
              <Switch checked={showPercentage} onChange={handleToggleChange} />
            }
            label="Show Percentage"
          />
        </div>
        <div className={classes.buttonGroup}>
          <FormControl
            variant="outlined"
            className={classes.formControlPatient}
          >
            <InputLabel id="patient-label">Patient</InputLabel>
            <Select
              labelId="patient-label"
              id="patient-select"
              value={selectedPatient}
              onChange={handlePatientChange}
              label="Patient"
            >
              {/* Add an option for "All Patients" */}
              <StyledMenuItem
                value={allPatientsOption}
                className={classes.menuItem}
              >
                No Patient
              </StyledMenuItem>
              {/* Aggregate and map through the predictionData to create a MenuItem for each unique patient across all versions */}
              {Object.values(predictionData)
                .flatMap((versionData) =>
                  versionData.map((patient) => patient.id),
                )
                .filter((value, index, self) => self.indexOf(value) === index) // Remove duplicates
                .sort((a, b) => {
                  const numberA = parseInt(a.split(" ")[1], 10); // Extract number from format "ID #"
                  const numberB = parseInt(b.split(" ")[1], 10);
                  return numberA - numberB; // Sort numerically
                })
                .map((patientId) => (
                  <StyledMenuItem
                    key={patientId}
                    value={patientId}
                    className={classes.menuItem}
                  >
                    {patientId}
                  </StyledMenuItem>
                ))}
            </Select>
          </FormControl>
        </div>
        <Plot
          data={modifiedBarTraces}
          layout={{
            title: "Gene Distribution in Clusters",
            barmode: "stack",
            xaxis: { title: "Clusters" },
            yaxis: { title: showPercentage ? "Percentage" : "Count" },
            margin: { l: 50, r: 275, b: 100 },
            height: dimensions.height / 1.5,
            width: dimensions.width / 2.25,
          }}
          onClick={handleBarClick}
          hoverlabel={{ bgcolor: "white", font: { color: "black" } }}
        />
      </div>
      <div className={classes.plotColumn}>
        <div className={classes.buttonGroup}>
          <Card
            style={{
              maxHeight: "175px",
              width: "100%",
              overflowY: "auto", // Allows for scrolling
              backgroundColor: "white",
              marginRight: "20px",
              display: "flex",
              flexDirection: "column",
              padding: "5px",
            }}
          >
            <Typography
              variant="h9"
              style={{
                marginBottom: "10px",
                justifyContent: "center",
                textAlign: "center",
              }}
            >
              Clicked Points
            </Typography>
            {clickedPoints
              .slice()
              .slice(0, 15)
              .map((pointInfo, index) => (
                <CardContent
                  key={index}
                  className={classes.clickedPointItem}
                  style={{
                    backgroundColor: index === 0 ? "black" : "white",
                    color: index === 0 ? "white" : "black",
                    fontSize: "1rem",
                  }}
                >
                  {pointInfo}
                </CardContent>
              ))}
          </Card>
          <FormControlLabel
            control={
              <Switch
                checked={showHoverTemplate}
                onChange={() => setShowHoverTemplate(!showHoverTemplate)}
              />
            }
            label="Show Hover Template"
            style={{ alignSelf: "flex-start" }} // Adjust positioning as needed
          />
          <Button
            variant="contained"
            onClick={resetView}
            className={classes.button}
          >
            Reset View
          </Button>
        </div>
        <Plot
          data={modifiedScatterTraces} // Use modifiedScatterTraces here
          layout={plotLayout}
          onRelayout={handleRelayout}
          onClick={(event) => {
            handleClickedPoint(event);
          }}
        />
      </div>
    </div>
  );
};

export default ThreeDClustering;
