import React, { useState, useEffect } from "react";
import Plot from "react-plotly.js";
import allGenes from "./genes.json";
import twoData from "./master_interp_fixed_revise_remove_unknown_loci_name_fix.json";
import {
  FormControl,
  InputLabel,
  Select,
  MenuItem,
  FormGroup,
  Checkbox,
  Box,
  Card,
  CardContent,
  Typography,
  Grid,
} from "@material-ui/core";
import { makeStyles } from "@material-ui/core/styles";
import ReactDOM from "react-dom";
import { styled } from "@mui/system";
import { useAudioProfile } from "../contexts/AudioProfileContext";

// Threshold keys used to read audiogram values from data rows (e.g. row["125 dB"]).
// The numeric prefix of each key is the test frequency in Hz; the " dB" suffix is
// part of the key itself and must match the data exactly.
const frequencies = [125, 250, 500, 1000, 1500, 2000, 3000, 4000, 6000, 8000].map(
  (hz) => `${hz} dB`,
);

// JSS classes for TwoDPlot, built with Material-UI v4 `makeStyles`.
// `theme.spacing(n)` is used for margins so spacing follows the app theme.
const useStyles = makeStyles((theme) => ({
  // Wrapper around each dropdown (gene / age ranges / ID) in the selector row.
  formControl: {
    margin: theme.spacing(1),
    minWidth: "10%",
    maxWidth: "60%",
    height: "65%",
    alignContent: "center",
    textAlign: "left",
  },
  // The selector row itself: a single, non-wrapping horizontal flex row.
  formControlRow: {
    margin: theme.spacing(1),
    alignContent: "center",
    width: "100%",
    textAlign: "center",
    display: "flex",
    flexDirection: "row",
    flexWrap: "nowrap",
  },
  // Root of the walkthrough view: a centered column.
  // `overflow: "scroll"` shows scrollbars even when content fits.
  container: {
    display: "flex",
    flexDirection: "column",
    alignItems: "center",
    overflow: "scroll",
  },
  // Dashboard card that wraps the grid of per-gene plots.
  card: {
    display: "flex",
    flexDirection: "column",
    margin: theme.spacing(2),
  },
  cardContent: {
    flexGrow: 1,
    display: "flex",
    flexDirection: "column",
    alignItems: "center",
    justifyContent: "center",
  },
  // Dropdown input: full width of its FormControl, italic text.
  select: {
    marginBottom: theme.spacing(2),
    width: "100%",
    height: "65%",
    fontStyle: "italic",
  },
  // Host <div> into which each Plotly chart is rendered.
  plotContainer: {
    margin: "auto",
    display: "flex",
    flexDirection: "column",
    alignItems: "center",
    justifyContent: "center",
    minHeight: "400px", // Set a larger minimum height
    height: "auto", // Allow the container to expand based on content
    width: "100%", // Ensure the plot fills the container
  },
  // Dashboard heading and introductory paragraph.
  title: {
    fontSize: 38,
    fontWeight: "bold",
    textAlign: "center",
    marginBottom: theme.spacing(2),
  },
  paragraph: {
    fontSize: 18,
    marginBottom: theme.spacing(2),
  },
}));

/**
 * Menu entry for the gene and age-range dropdowns.
 *
 * - Normal state: black italic text on white.
 * - Hover: the text turns gold.
 * - Selected entry: white on black.
 *
 * MUI marks the selected item with the global state class `Mui-selected`.
 * No `MuiMenuItem-selected` class exists, so the selected style must target
 * `Mui-selected` or it never applies.
 */
const StyledMenuItem = styled(MenuItem)({
  "&.MuiMenuItem-root": {
    color: "black",
    backgroundColor: "white",
    fontStyle: "italic",
    "&:hover": {
      color: "gold",
      backgroundColor: "white",
    },
  },
  // Chained with `.MuiMenuItem-root` so this rule is more specific than the
  // white background above and MUI's default selected shading.
  "&.MuiMenuItem-root.Mui-selected": {
    backgroundColor: "black",
    color: "white",
    // Keep the background dark on hover so the white text stays legible.
    "&:hover": {
      backgroundColor: "black",
    },
  },
  width: "100%",
});

const TwoDPlot = ({ usingWalkthrough }) => {
  const { predictionData, processedTableData } = useAudioProfile();
  const classes = useStyles();
  const sortedAllGenes = [...allGenes].sort();
  let initialGene = sortedAllGenes[0];
  if (usingWalkthrough) {
    initialGene =
      predictionData.audiogenev4?.[0]?.genes[0] ||
      predictionData.audiogenev9?.[0]?.genes[0] ||
      initialGene;
  }

  const [selectedGene, setSelectedGene] = useState(initialGene);
  const [selectedId, setSelectedId] = useState(null);
  const [selectedAges, setSelectedAges] = useState([0, 1, 2, 3]);
  const [dimensions, setDimensions] = useState({
    width: window.innerWidth,
    height: window.innerHeight,
  });

  useEffect(() => {
    const handleResize = () => {
      setDimensions({
        width: window.innerWidth,
        height: window.innerHeight,
      });
    };

    window.addEventListener("resize", handleResize);
    return () => {
      window.removeEventListener("resize", handleResize);
    };
  }, []);

  useEffect(() => {
    generate2DPlot(selectedGene, 0, true);
  }, [processedTableData, selectedGene, selectedAges, selectedId, dimensions]);

  useEffect(() => {
    if (sortedAllGenes.length > 0) {
      sortedAllGenes.forEach((gene, index) => {
        if (document.getElementById(`plot-plot-2D-${index}`)) {
          generate2DPlot(gene, index, false);
        }
      });
    }
  }, [sortedAllGenes]);

  const generate2DPlot = (gene, index, usingDashboard) => {
    const meanValues = calculateMeanValues(twoData, frequencies);
    const filteredTwoData = twoData.filter((item) => item.locus === gene);
    const twoDataProcessed = processData(
      filteredTwoData,
      frequencies,
      meanValues,
      selectedAges,
    );

    let tableDataProcessed = [];
    if (usingWalkthrough) {
      tableDataProcessed = processTableData(
        processedTableData,
        frequencies,
        selectedId,
      );
    }

    renderPlot(
      [...twoDataProcessed, ...tableDataProcessed],
      `plot-plot-2D-${index}`,
      gene,
      usingDashboard,
    );
  };

  const clipValue = (value, min, max) => {
    return Math.min(Math.max(value, min), max);
  };

  const colorBlindFriendlyPalette = [
    "#F3C677",
    "#F1924E",
    "#5A4FA2",
    "#00204C",
  ];

  const processData = (data, frequencies, meanValues, selectedAges) => {
    const ageRanges = [[], [], [], []];
    data.forEach((item) => {
      const ageIndex = Math.min(Math.floor(item.age / 20), 3);
      if (selectedAges.includes(ageIndex)) {
        ageRanges[ageIndex].push(item);
      }
    });

    return ageRanges
      .map((range, index) => {
        if (range.length === 0 || !selectedAges.includes(index)) {
          return null;
        }
        const nameWithCount = `Ages ${index * 20}-${(index + 1) * 20 - 1} Count=${range.length}`;
        const xValues = frequencies.map((freq) => freq.split(" ")[0]);
        const yValues = frequencies.map((freq) => {
          const values = range.map((item) =>
            item[freq] !== undefined
              ? clipValue(parseFloat(item[freq]), -10, 120)
              : clipValue(meanValues[freq], -10, 120),
          );
          return values.reduce((a, b) => a + b, 0) / values.length;
        });
        const hoverTexts = xValues.map(
          (x, i) => `Frequency: ${x} Hz<br>dB HL: ${yValues[i].toFixed(2)}`,
        );
        return {
          x: xValues,
          y: yValues,
          mode: "lines+markers",
          name: nameWithCount,
          text: hoverTexts,
          hoverinfo: "text",
          line: {
            color:
              colorBlindFriendlyPalette[
                index % colorBlindFriendlyPalette.length
              ],
          },
          marker: {
            color:
              colorBlindFriendlyPalette[
                index % colorBlindFriendlyPalette.length
              ],
          },
        };
      })
      .filter((trace) => trace !== null);
  };

  const calculateMeanValues = (data, frequencies) => {
    const sums = frequencies.reduce((acc, freq) => ({ ...acc, [freq]: 0 }), {});
    const counts = frequencies.reduce(
      (acc, freq) => ({ ...acc, [freq]: 0 }),
      {},
    );

    data.forEach((item) => {
      frequencies.forEach((freq) => {
        if (item[freq] !== undefined && item[freq] !== null) {
          sums[freq] += parseFloat(item[freq]);
          counts[freq]++;
        }
      });
    });

    return frequencies.reduce((acc, freq) => {
      acc[freq] = counts[freq] ? sums[freq] / counts[freq] : 0;
      return acc;
    }, {});
  };

  const processTableData = (data, frequencies, selectedId) => {
    if (!selectedId) return [];

    const rows = data.filter((row) => row.id === selectedId);
    if (rows.length === 0) return [];

    let traces = [];

    rows.forEach((row, index) => {
      let plotData = frequencies
        .map((freq) => {
          if (row[freq] !== undefined) {
            return {
              x: freq.split(" ")[0],
              y: clipValue(parseFloat(row[freq]), -10, 120),
            };
          }
          return null;
        })
        .filter((point) => point !== null);

      if (plotData.length > 0) {
        const trace = {
          x: plotData.map((point) => point.x),
          y: plotData.map((point) => point.y),
          mode: "lines+markers",
          type: "scatter",
          name: `Patient ${selectedId} - Age ${row.age}`,
          line: { color: "red" },
          marker: { color: "red" },
        };

        traces.push(trace);
      }
    });

    return traces;
  };

  const renderPlot = (traces, containerId, gene, usingDashboard) => {
    const container = document.getElementById(containerId);
    if (!container) return;

    // Clear the previous plot if it exists
    ReactDOM.unmountComponentAtNode(container);

    // Get the dimensions of the container
    const rect = container.getBoundingClientRect();

    const layout = {
      responsive: true,
      title: {
        text: `AudioProfile for <i>${gene}</i>`,
        //text: `<i>${gene}</i>`,  alternative title
        font: {
          size: 25,
        },
        y: usingDashboard ? 0.96 : 0.77, // Position the title lower relative to the top of the plot area
        x: 0.43, // Center the title horizontally
        xanchor: "center",
        yanchor: "bottom",
      },
      xaxis: {
        title: "Frequency (Hz)",
        type: "log",
        autorange: true,
        tickvals: [125, 250, 500, 1000, 2000, 4000, 8000],
        ticktext: ["125", "250", "500", "1k", "2k", "4k", "8k"],
        tickfont: {
          size: 14,
        },
      },
      yaxis: {
        title: "Hearing Level (dB HL)",
        type: "linear",
        autorange: false,
        range: [130, -10],
        tickvals: [
          -10, 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130,
        ],
        ticktext: [
          "-10",
          "0",
          "10",
          "20",
          "30",
          "40",
          "50",
          "60",
          "70",
          "80",
          "90",
          "100",
          "110",
          "120",
          "130",
        ],
        tickfont: {
          size: 14,
        },
      },
      shapes: [
        // Dotted lines for specific frequencies
        {type: 'line', x0: 1500, x1: 1500, y0: -10, y1: 130, xref: 'x', yref: 'y', line: {color: 'lightgrey', width: 1, dash: 'dot'}, layer: 'below'},
        {type: 'line', x0: 3000, x1: 3000, y0: -10, y1: 130, xref: 'x', yref: 'y', line: {color: 'lightgrey', width: 1, dash: 'dot'}, layer: 'below'},
        {type: 'line', x0: 6000, x1: 6000, y0: -10, y1: 130, xref: 'x', yref: 'y', line: {color: 'lightgrey', width: 1, dash: 'dot'}, layer: 'below'},
      ],
      legend: {
        font: {
          size: 10, // Smaller font size for legend text
        },
        orientation: "h", // Horizontal layout
        x: 0.5, // Centered horizontally
        y: 1.1, // Positioned above the top of the plot area
        xanchor: "center", // Anchor the legend at its center
        yanchor: "bottom", // Anchor the legend at its bottom
      },
      margin: { l: 40, r: 100, b: 50, t: 225 },
      paper_bgcolor: "white",
      plot_bgcolor: "white",
      hovermode: "closest",
      grid: { rows: 1, columns: 1, pattern: "independent" },
      width: rect.width, // Set width based on container size
      height: Math.max(rect.width, 560), // Set height based on container size
    };

    traces.forEach((trace) => {
      trace.mode = "lines+markers";
    });

    ReactDOM.render(<Plot data={traces} layout={layout} />, container);
  };

  const handleAgeChange = (event) => {
    setSelectedAges(event.target.value);
  };

  const handleIdChange = (event) => {
    setSelectedId(event.target.value);
  };

  const ids = processedTableData
    ? Array.from(new Set(processedTableData.map((row) => row.id)))
    : [];

  if (usingWalkthrough) {
    return (
      <Box padding={2} className={classes.container}>
        <FormGroup row className={classes.formControlRow}>
          <FormControl
            variant="outlined"
            margin="normal"
            className={classes.formControl}
          >
            <InputLabel id="gene-select-label">Gene</InputLabel>
            <Select
              labelId="gene-select-label"
              id="gene-select"
              value={selectedGene}
              onChange={(e) => setSelectedGene(e.target.value)}
              label="Gene"
              className={classes.select}
            >
              {sortedAllGenes.map((gene) => (
                <StyledMenuItem key={gene} value={gene}>
                  {gene}
                </StyledMenuItem>
              ))}
            </Select>
          </FormControl>
          {processedTableData && (
            <FormControl
              variant="outlined"
              margin="normal"
              className={classes.formControl}
            >
              <InputLabel id="age-range-label">Age Ranges</InputLabel>
              <Select
                labelId="age-range-label"
                id="age-range-select"
                multiple
                value={selectedAges}
                onChange={handleAgeChange}
                className={classes.select}
                renderValue={(selected) =>
                  selected
                    .map((age) => `Ages ${age * 20}-${(age + 1) * 20 - 1}`)
                    .join(", ")
                }
              >
                {Array.from({ length: 4 }, (_, i) => i).map((age) => (
                  <StyledMenuItem key={age} value={age}>
                    <Checkbox checked={selectedAges.includes(age)} />
                    {`Ages ${age * 20}-${(age + 1) * 20 - 1}`}
                  </StyledMenuItem>
                ))}
              </Select>
            </FormControl>
          )}
          {processedTableData && (
            <FormControl
              variant="outlined"
              margin="normal"
              className={classes.formControl}
            >
              <InputLabel id="id-select-label">ID</InputLabel>
              <Select
                labelId="id-select-label"
                id="id-select"
                value={selectedId}
                onChange={handleIdChange}
                label="ID"
              >
                <MenuItem value={null}>None</MenuItem>
                {ids.map((id) => (
                  <MenuItem key={id} value={id}>
                    {id}
                  </MenuItem>
                ))}
              </Select>
            </FormControl>
          )}
        </FormGroup>
        <div
          className={classes.plotContainer}
          style={{
            minWidth: usingWalkthrough ? "500px" : "60vw",
            minHeight: usingWalkthrough ? "700px" : "60vh",
          }}
          id="plot-plot-2D-0"
        ></div>
      </Box>
    );
  } else {
    return (
      <Card className={classes.card}>
        <CardContent className={classes.cardContent}>
          <Typography variant="h4" component="h2" className={classes.title}>
            Audio Profiles
          </Typography>
          <Typography
            variant="body2"
            color="textSecondary"
            component="p"
            className={classes.paragraph}
          >
            Audioprofiles are average audiograms based on all audiograms in our
            dataset for each locus grouped in two-decade increments.
          </Typography>
          <Grid container spacing={2}>
            {sortedAllGenes.map((gene, index) => (
              <Grid item xs={12} sm={6} md={4} key={gene}>
                <div
                  className={classes.plotContainer}
                  style={{
                    maxWidth: dimensions.width <= 900 ? "700px" : "85vw",
                    minWidth: "450px",
                    maxHeight: dimensions.height <= 700 ? "700px" : "85vh",
                    minHeight: "500px",
                  }}
                  id={`plot-plot-2D-${index}`}
                ></div>
              </Grid>
            ))}
          </Grid>
        </CardContent>
      </Card>
    );
  }
};

export default TwoDPlot;
