import React, { useEffect, useRef, useState } from "react";
import geneXData from "./genes_x.json";
import geneZData from "./genes_z.json";
import {
  Box,
  Button,
  FormControl,
  InputLabel,
  MenuItem,
  Select,
  TextField,
  Grid,
  Typography,
  CardContent,
} from "@mui/material";
import { Card } from "@mui/material";
import { makeStyles } from "@material-ui/core";
import { dimensions } from "../constants";
import { styled } from "@mui/system";
import Switch from "@mui/material/Switch";
import FormControlLabel from "@mui/material/FormControlLabel";
import Plotly from "plotly.js-dist";
import { useAudioProfile } from "../contexts/AudioProfileContext";

// JSS class names for SurfaceViewer, built with makeStyles from @material-ui/core (v4).
// NOTE(review): the rest of this file styles with @mui v5 (`sx`, `styled`); moving these
// rules over would leave the component depending on a single MUI major version.
const useStyles = makeStyles(() => ({
  // Select column in the toolbar row.
  formControl: {
    minWidth: 120,
    width: "20%",
    textAlign: "left",
    marginRight: "10px",
  },
  container: {
    display: "flex",
    flexDirection: "column",
    alignItems: "center",
    width: "100%",
    height: "100%",
  },
  // Black-on-white menu entry that inverts on hover and when selected.
  menuItem: {
    color: "black",
    "&:hover": {
      backgroundColor: "black",
      color: "white",
    },
    "&:focus": {
      backgroundColor: "white",
      color: "black",
    },
    "&.Mui-selected": {
      backgroundColor: "black",
      color: "white",
    },
  },
  select: {
    width: "100%",
    fontStyle: "italic",
  },
  textField: {
    width: "100%",
  },
  // Row layout: items placed side by side, vertically centred.
  plotContainer: {
    display: "flex",
    flexDirection: "row",
    alignItems: "center",
    justifyContent: "center",
    width: "100%",
    height: "100%",
    overflowY: "auto",
  },
  // Column layout: items stacked top to bottom, horizontally centred.
  plotContainerWalkthrough: {
    display: "flex",
    flexDirection: "column",
    alignItems: "center",
    justifyContent: "center",
    width: "100%",
    height: "100%",
    overflowY: "auto",
  },
  clickedPointsList: {
    height: dimensions.height / 2, // half of the configured `dimensions.height`
    overflowY: "auto", // scrollbar appears only when the list overflows
    backgroundColor: "white",
    width: "30%",
    marginRight: "20px",
    display: "flex",
    flexDirection: "column",
    padding: "10px",
  },
  // One clicked-point entry; the first child is highlighted (inverted colours).
  clickedPointItem: {
    height: "fit-content",
    padding: "10px",
    backgroundColor: "white",
    color: "black",
    "&:first-child": {
      backgroundColor: "black",
      color: "white",
    },
  },
}));

// MenuItem with black-on-white italic text that inverts on hover and while selected.
const StyledMenuItem = styled(MenuItem)(() => ({
  "&.MuiMenuItem-root": {
    color: "black",
    "&:hover": {
      backgroundColor: "black",
      color: "white",
    },
  },
  // MUI v5 flags the selected entry with the global state class `Mui-selected`
  // (there is no `MuiMenuItem-selected` class), so that is the selector to target.
  "&.Mui-selected": {
    backgroundColor: "black",
    color: "white",
  },
  width: "100%",
  fontStyle: "italic",
}));

// Gene names offered in the selector: the keys of the X dataset, in default sort order.
const genes = Object.keys(geneXData).sort();

// Test frequencies in Hz, ascending; shared by the surface, the traces and the form.
const frequencies = [125, 250, 500, 1000, 1500, 2000, 3000, 4000, 6000, 8000];

const SurfaceViewer = ({ usingWalkthrough }) => {
  const classes = useStyles();

  const { processedTableData } = useAudioProfile();

  const [selectedGene, setSelectedGene] = useState(genes[0]);
  const [cameraEye, setCameraEye] = useState({
    x: 2.2149571135580044,
    y: -1.6606448968632268,
    z: 0.37864572699016247,
  });
  const [audiogram, setAudiogram] = useState({
    age: "",
    values: Array(frequencies.length).fill(""),
  });
  const [showLabels, setShowLabels] = useState(true);
  const [frozenPoint, setFrozenPoint] = useState(null);
  const [clickedPoints, setClickedPoints] = useState([]);
  const [selectedPatientAudiogram, setSelectedPatientAudiogram] =
    useState(null);

  const freezeAtPoint = (point) => {
    setClickedPoints((prevPoints) => {
      const pointExists = prevPoints.some(
        (prevPoint) =>
          prevPoint.x === point.x &&
          prevPoint.y === point.y &&
          prevPoint.z === point.z,
      );
      if (!pointExists) {
        return [point, ...prevPoints];
      }
      return prevPoints;
    });
  };

  useEffect(() => {
    Generate3DPlot();
  }, [selectedGene, audiogram, selectedPatientAudiogram]);

  const reversedViridisColorScale = [
    [0, "#440154"], // Dark purple, was at the top (1.0), now at the bottom (0.0)
    [0.1, "#482878"], // Purple
    [0.2, "#3e4989"], // Blue-purple
    [0.3, "#31688e"], // Blue
    [0.4, "#26828e"], // Cyan
    [0.5, "#1f9e89"], // Green-cyan
    [0.6, "#35b779"], // Green
    [0.7, "#6ece58"], // Yellow-green
    [0.8, "#b5de2b"], // Yellow
    [0.9, "#fde725"], // Yellow-white
    [1, "#f0f921"], // Bright yellow, was at the bottom (0.0), now at the top (1.0)
  ];

  const customColorBarScale = [
    [0, "#440154"], // Dark purple, at the bottom
    [0.1, "#482878"], // Purple
    [0.2, "#3e4989"], // Blue-purple
    [0.3, "#31688e"], // Blue
    [0.4, "#26828e"], // Cyan
    [0.5, "#1f9e89"], // Green-cyan
    [0.6, "#35b779"], // Green
    [0.7, "#6ece58"], // Yellow-green
    [0.8, "#b5de2b"], // Yellow
    [0.9, "#fde725"], // Yellow-white
    [1, "#fde725"], // Yellow, at the top
  ];

  // Create a dummy trace for the custom color bar
  const colorBarTrace = {
    z: [[0, 1]], // A 2D array with a single row
    x: [0, 1], // X values don't matter, but need to be present
    y: [0, 1], // Y values span from 0 to 1
    type: "heatmap",
    colorscale: customColorBarScale,
    showscale: false, // We don't want the default color bar
    hoverinfo: "none", // Disable hover info for the color bar
  };

  const Generate3DPlot = () => {
    var keyName = selectedGene;
    var yDataReverse = frequencies; // Convert Hz to kHz and reverse for the plot
    var xData = geneXData[keyName].slice();
    var zData = geneZData[keyName].slice();

    // Modify the hovertemplate to always show the frozen point's details
    const hovertemplate = frozenPoint
      ? `Age: ${frozenPoint.x}<br>Frequency (kHz): ${frozenPoint.y}<br>dB HL: ${frozenPoint.z} dB<extra></extra>`
      : "Age: %{x}<br>Frequency (kHz): %{y}<br>dB HL: %{z} dB<extra></extra>";

    var data = [
      {
        x: xData,
        y: yDataReverse,
        z: zData,
        type: "surface",
        hovertemplate: showLabels ? hovertemplate : undefined,
        hoverinfo: showLabels ? undefined : "none",
        colorscale: customColorBarScale,
        reversescale: true,
        showscale: true,
        colorbar: {
          tickvals: [-10, 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
          ticktext: [
            "0",
            "10",
            "20",
            "30",
            "40",
            "50",
            "60",
            "70",
            "80",
            "90",
            "100",
          ],
          len: 0.5,
          y: 0.5,
          thickness: 20,
          tickmode: "array",
          tickangle: -45,
          tickfont: { size: 15 },
        },
      },
    ];

    // Add user's audiogram to the plot if values are present
    if (audiogram.values.some((val) => val !== "")) {
      const age = parseFloat(audiogram.age);
      let audiogramValues = audiogram.values.map((val) =>
        val === "" ? null : Number(val),
      );
      let yDataAudiogram = frequencies; // Convert Hz to kHz for the audiogram

      // Filter out the frequencies and corresponding values for the plot
      const filteredAudiogramValues = audiogramValues.filter(
        (val) => val !== null,
      );
      const filteredFrequencies = yDataAudiogram.filter(
        (val, index) => audiogramValues[index] !== null,
      );

      data.push({
        x: Array(filteredAudiogramValues.length).fill(age),
        y: filteredFrequencies,
        z: filteredAudiogramValues,
        hovertemplate: showLabels
          ? "Age (Years): %{x}<br>Frequency (kHz): %{y}<br>dB HL: %{z} dB<extra></extra>"
          : undefined,
        hoverinfo: showLabels ? undefined : "none",
        mode: "lines",
        type: "scatter3d",
        line: {
          color: "red",
          width: 6,
        },
        name: `Patient Audiogram`,
      });
    }
    if (
      usingWalkthrough &&
      selectedPatientAudiogram &&
      selectedPatientAudiogram !== "No Patient"
    ) {
      console.log(processedTableData);
      const patientDataEntries = processedTableData.filter(
        (data) => data.id === selectedPatientAudiogram,
      );
      if (patientDataEntries.length) {
        patientDataEntries.forEach((patientData, index) => {
          const age = patientData.age;
          const audiogramValues = frequencies.map(
            (freq) => patientData[freq + " dB"],
          );
          console.log("The patient's audiogram values are: ", audiogramValues);

          // Determine the color based on the index to create a gradient from light red to dark red
          const colorGradient = [
            "#ffcccc",
            "#ff9999",
            "#ff6666",
            "#ff3333",
            "#ff0000",
          ];
          let color;
          if (index < colorGradient.length) {
            color = colorGradient[index];
          } else {
            color = colorGradient[colorGradient.length - 1]; // Use the darkest red for overflow
          }

          data.push({
            x: Array(audiogramValues.length).fill(age),
            y: frequencies,
            z: audiogramValues,
            hovertemplate: showLabels
              ? `Patient: ${patientData.id}<br> Age (Years): %{x}<br>Frequency (kHz): %{y}<br>dB HL: %{z} dB<extra></extra>`
              : undefined,
            hoverinfo: showLabels ? undefined : "none",
            mode: "lines",
            type: "scatter3d",
            line: {
              color: color,
              width: 6,
            },
            name: `Patient ${patientData.id}, Age: ${Number.isInteger(age) ? age : age.toFixed(2)}`,
            showlegend: true, // Ensure the legend is shown for this trace
          });
        });
      }
    }

    let width = dimensions.width / 2;
    let height = dimensions.height / 1.4;
    if (usingWalkthrough) {
      width = dimensions.width / 2.4;
      height = dimensions.height / 1.4;
    }

    var layout = {
      width: width,
      height: height,
      title: `<i>${keyName}</i>`,
      scene: {
        xaxis: {
          title: "Age (Years)",
          hoverlabel: { bgcolor: "#FFF" }, // Optional: style for hover label
          showline: false,
          zeroline: false,
          showticklabels: true, // Set to false if you want to hide tick labels
        },
        yaxis: {
          title: "Frequency (Hz)",
          tickvals: frequencies,
          ticktext: frequencies.map((f) =>
            f >= 1000 ? `${f / 1000}k` : `${f}`,
          ),
          range: [Math.min(...frequencies), Math.max(...frequencies)].reverse(),
          hoverlabel: { bgcolor: "#FFF", font: { color: "#000" } }, // Optional: style for hover label
          tickangle: -45,
          type: "log", // Specify that this axis uses a logarithmic scale
          autorange: true,
          showline: false,
          zeroline: false,
          showticklabels: true, // Set to false if you want to hide tick labels
        },
        zaxis: {
          title: "dB HL",
          range: [130, -10], // Reversed range for the z-axis
          autorange: false, // Disable autorange to use the specified range
          hoverlabel: { bgcolor: "#FFF" }, // Optional: style for hover label
          tickangle: -45,
          showline: false,
          zeroline: false,
          showticklabels: true, // Set to false if you want to hide tick labels
        },
        camera: { eye: cameraEye },
      },
      xaxis: { visible: false }, // Hide axes for the color bar
      yaxis: { visible: false },
      margin: { t: 150, r: 300, b: 0, l: 0 }, // Minimize margins
    };

    Plotly.newPlot("plot-plot-3D", data, layout).then(() => {
      const plotElement = document.getElementById("plot-plot-3D");
      plotElement.on("plotly_click", (eventData) => {
        const point = {
          x: eventData.points[0].x,
          y: eventData.points[0].y,
          z: eventData.points[0].z,
        };
        freezeAtPoint(point);
      });
    });
  };

  // Effect hook to re-render the plot when the frozen line data changes
  // Effect hook to re-render the plot when the frozen point changes
  useEffect(() => {
    Generate3DPlot();
  }, [frozenPoint, selectedGene, audiogram, showLabels, clickedPoints]);

  const onToggleLabels = () => {
    setShowLabels(!showLabels);
  };

  useEffect(() => {
    Generate3DPlot();
  }, [selectedGene, audiogram, showLabels, cameraEye]);

  const handleGeneChange = (event) => {
    setSelectedGene(event.target.value);
  };

  const handleAudiogramChange = (index) => (event) => {
    const newAudiogramValues = [...audiogram.values];
    newAudiogramValues[index] = event.target.value;
    setAudiogram({ ...audiogram, values: newAudiogramValues });
  };

  const handleAgeChange = (event) => {
    setAudiogram({ ...audiogram, age: event.target.value });
  };

  const resetGraph = () => {
    setAudiogram({ age: "", values: Array(frequencies.length).fill("") });
    Generate3DPlot();
  };

  const resetCamera = () => {
    setCameraEye({
      x: 2.2149571135580044,
      y: -1.6606448968632268,
      z: 0.37864572699016247,
    });
  };

  const resetClickedPoints = () => {
    setClickedPoints([]);
  };

  const clickedPointsCard = (
    <>
      {usingWalkthrough && (
        <Typography
          variant="h6"
          style={{
            marginBottom: "10px",
            justifyContent: "center",
            textAlign: "center",
          }}
        >
          Clicked Points
        </Typography>
      )}
      <Card
        sx={{
          maxHeight: "100%", // Adjust the height as needed
          maxWidth: "80%", // Adjust the width as needed
          overflowY: "auto", // Allows for scrolling
          overflowX: "auto", // Hide horizontal scrollbar
          backgroundColor: "white",
          marginRight: "20px",
          width: usingWalkthrough ? "80%" : "20%", // Adjust the width as needed
          display: "flex",
          flexDirection: usingWalkthrough ? "row" : "column", // Change flex direction based on usingWalkthrough
          padding: "10px",
        }}
      >
        {!usingWalkthrough && (
          <Typography
            variant="h6"
            style={{
              marginBottom: "10px",
              justifyContent: "center",
              textAlign: "center",
            }}
          >
            Clicked Points
          </Typography>
        )}
        {clickedPoints.map((point, index) => (
          <CardContent
            key={index}
            className={classes.clickedPointItem}
            style={{
              backgroundColor: index === 0 ? "black" : "white",
              color: index === 0 ? "white" : "black",
            }}
          >
            <Typography>Age: {point.x}</Typography>
            <Typography>Frequency (kHz): {point.y}</Typography>
            <Typography>dB HL: {point.z} dB</Typography>
          </CardContent>
        ))}
      </Card>
    </>
  );

  return (
    <Box sx={{ flexGrow: 1, padding: 2 }}>
      <Grid container spacing={2}>
        <Grid item xs={12}>
          <FormControl className={classes.formControl} variant="filled">
            <InputLabel id="gene-select-label">Gene</InputLabel>
            <Select
              labelId="gene-select-label"
              id="gene-select"
              value={selectedGene}
              onChange={handleGeneChange}
              label="Gene"
              className={classes.select}
            >
              {genes.map((gene) => (
                <StyledMenuItem value={gene}>{gene}</StyledMenuItem>
              ))}
            </Select>
          </FormControl>
          <FormControlLabel
            control={
              <Switch checked={showLabels} onChange={() => onToggleLabels()} />
            }
            label="Show Labels"
            sx={{
              marginLeft: "10px",
              marginTop: "10px",
            }}
          />
          <Button
            sx={{
              marginTop: 1,
              width: "15%",
              color: "white",
              backgroundColor: "black",
              "&:hover": {
                backgroundColor: "white",
                color: "black",
              },
            }}
            onClick={resetCamera}
          >
            Reset View
          </Button>
          <Button
            sx={{
              marginTop: 1,
              marginLeft: "10px",
              width: "15%",
              color: "white",
              backgroundColor: "black",
              "&:hover": {
                backgroundColor: "white",
                color: "black",
              },
            }}
            onClick={resetClickedPoints}
          >
            Reset Clicked Points
          </Button>
          {usingWalkthrough && (
            <FormControl className={classes.formControl} variant="filled">
              <InputLabel id="patient-audiogram-select-label">
                Patient Audiogram
              </InputLabel>
              <Select
                labelId="patient-audiogram-select-label"
                id="patient-audiogram-select"
                value={selectedPatientAudiogram}
                className={classes.select}
                sx={{ width: "100%", marginLeft: "10px" }}
                onChange={(event) =>
                  setSelectedPatientAudiogram(event.target.value)
                }
              >
                <StyledMenuItem value="No Patient" className={classes.menuItem}>
                  No Patient
                </StyledMenuItem>
                {processedTableData
                  .filter(
                    (value, index, self) =>
                      self.findIndex((v) => v.id === value.id) === index,
                  )
                  .map((data, index) => (
                    <StyledMenuItem key={index} value={data.id}>
                      {`Patient ${data.id}`}
                    </StyledMenuItem>
                  ))}
              </Select>
            </FormControl>
          )}
        </Grid>
      </Grid>
      <Box
        className={
          usingWalkthrough
            ? classes.plotContainerWalkthrough
            : classes.plotContainer
        }
      >
        {!usingWalkthrough && clickedPointsCard}
        <div id="plot-plot-3D"></div>
        {usingWalkthrough && clickedPointsCard}
      </Box>
      {!usingWalkthrough && (
        <Grid container spacing={2}>
          {/* Age TextField in its own row */}
          <Grid item xs={6}>
            <TextField
              fullWidth
              label="Age (Years)"
              value={audiogram.age}
              onChange={handleAgeChange}
            />
          </Grid>
          {/* Empty Grid item to balance the layout */}
          <Grid item xs={6}></Grid>

          {/* Frequency TextFields in subsequent rows */}
          {frequencies.map((frequency, index) => (
            <Grid item xs={6} key={frequency}>
              <TextField
                fullWidth
                label={`${frequency >= 1000 ? frequency / 1000 + "k" : frequency} Hz`}
                value={audiogram.values[index]}
                onChange={handleAudiogramChange(index)}
              />
            </Grid>
          ))}

          {/* If the number of frequencies is odd, add an extra empty Grid item to balance the last row */}
          {frequencies.length % 2 !== 0 && <Grid item xs={6}></Grid>}
          <Grid item xs={12}>
            <Button
              sx={{
                marginTop: 1,
                width: "15%",
                color: "white",
                backgroundColor: "black",
                "&:hover": {
                  backgroundColor: "white",
                  color: "black",
                },
              }}
              onClick={resetGraph}
            >
              Reset Graph
            </Button>
          </Grid>
        </Grid>
      )}
    </Box>
  );
};

export default SurfaceViewer;
