// S3FileViewer.tsx
import React, { useEffect, useState } from "react";
import {
  fetchFiles,
  deleteFile,
  getDownloadUrl,
  S3File,
} from "../../../../../components/S3/S3Utils";
import {
  Paper,
  Table,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
  Grid,
  Tabs,
  Tab,
  CircularProgress,
  FormControl,
  Alert,
  DialogContentText,
  DialogTitle,
  Button,
  Box,
  List,
  ListItemButton,
  ListItemIcon,
  ListItemText,
} from "@mui/material";
import DeleteIcon from "@mui/icons-material/Delete";
import DownloadIcon from "@mui/icons-material/Download";
import { readJsonFromS3 } from "../../../../../components/S3/S3Utils";
import Plot from "react-plotly.js";
import { Data } from "plotly.js";

/**
 * ROC-curve payload (`visualization_plots.auc`).
 *
 * Each map is keyed by class label; those keys become the trace names in the
 * ROC plot.
 */
interface AUCProps {
  /** Area under the ROC curve, per class. */
  auc?: Record<string, number>;
  /** True-positive rates per class (plotted on the y axis). */
  tpr?: Record<string, number[]>;
  /** False-positive rates per class (plotted on the x axis). */
  fpr?: Record<string, number[]>;
  /** Per-class thresholds; currently not plotted. */
  thresholds?: Record<string, number[]>;
  /** Class names as sent by the backend; not read by this viewer. */
  class_names?: string[];
}

/**
 * Precision-recall payload (`visualization_plots.pr`).
 *
 * Each map is keyed by class label; those keys become the trace names in the
 * PR plot.
 */
interface PRProps {
  /** F1 score per class (shown in the legend). */
  f1?: Record<string, number>;
  /** Precision values per class (plotted on the y axis). */
  precision?: Record<string, number[]>;
  /** Recall values per class (plotted on the x axis). */
  recall?: Record<string, number[]>;
  /** Per-class thresholds; currently not plotted. */
  thresholds?: Record<string, number[]>;
  /** Class names as sent by the backend; not read by this viewer. */
  class_names?: string[];
}

/**
 * Confusion-matrix payload (`visualization_plots.confusion_matrix`).
 *
 * Rendered as a heatmap whose rows are labelled "True Label" and whose
 * columns are labelled "Predicted Label"; `labels` names both axes.
 */
interface ConfusionMatrixProps {
  /** Matrix cells, row-major (`values[row][column]`). */
  values?: Array<Array<number>>;
  /** Class labels used for both the row and the column axis. */
  labels?: Array<string>;
}

/**
 * Metrics payload (`ApiResponse.metrics`), consumed as a flat map from
 * metric/column name to value.
 *
 * The viewer handles two layouts:
 * - single model: each entry is `metricName -> number`;
 * - per-model table: a `"Model"` column maps row keys to model names, and
 *   each metric column maps the same row keys to that model's value.
 *
 * Values are heterogeneous (numbers or nested column maps), so they stay
 * loosely typed here and are narrowed at the point of use.
 */
interface MetricsProps {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- heterogeneous metric columns
  [key: string]: any;
}

/** Plot payloads bundled under `ApiResponse.visualization_plots`. */
interface VisualizationPlotProps {
  /** ROC curves and per-class AUC. */
  auc?: AUCProps;
  /** Precision-recall curves and per-class F1. */
  pr?: PRProps;
  /** Confusion matrix cells and axis labels. */
  confusion_matrix?: ConfusionMatrixProps;
}

/**
 * Shape of one response JSON file read from S3 by the viewer.
 *
 * Every field is optional because the file is parsed from external JSON and
 * may be partial (e.g. while a job is still running).
 */
interface ApiResponse {
  task_type?: string | null;
  out_location?: string | null;
  /** ROC / PR / confusion-matrix data for the plot tabs. */
  visualization_plots?: VisualizationPlotProps;
  /** Not read by this viewer; kept `unknown` so it must be narrowed before use. */
  explanation_plots?: unknown;
  /** Data for the metrics table. */
  metrics?: MetricsProps;
  /**
   * Job status. The viewer special-cases `"requested"` (in progress) and
   * `"failed"`; any other value renders the result tabs.
   */
  status?: string | null;
}

/** Props of the classification response viewer component. */
interface TrainResponseViewerProps {
  /**
   * Opaque re-fetch trigger: any change in value makes the viewer list the
   * response files again. Its value is never otherwise read.
   */
  refresh: unknown;
  /** S3 bucket holding the response files. */
  bucketName: string;
  /** Prefix under which `responses/<serviceName>` is listed. */
  pathPrefix: string;
  /** Service whose responses are shown (last path segment of the listing prefix). */
  serviceName: string;
}

/** Props for {@link ROCPlot}. */
interface ROCProps {
  rocData?: AUCProps;
}

/** Props for {@link PRCurve}. */
interface PRPlotProps {
  prData?: PRProps;
}

/** Props for {@link ConfusionMatrix}. */
interface ConfusionMatrixPlotProps {
  confMatrix?: ConfusionMatrixProps;
}

/** Props for {@link MetricsTable}. */
interface MetricsTableProps {
  metrics?: MetricsProps;
}

/** Column that holds the model names when metrics come as a per-model table. */
const MODEL_COLUMN = "Model";

/**
 * Columns of a per-model metrics table that are not metrics and therefore are
 * not rendered as rows. ("Unnamed: 0" looks like a serialized index column —
 * confirm against the producer of the JSON.)
 */
const NON_METRIC_COLUMNS = new Set<string>([MODEL_COLUMN, "Unnamed: 0"]);

/** Removes the bottom border of the last table row. */
const LAST_ROW_SX = { "&:last-child td, &:last-child th": { border: 0 } };

/** Layout shared by the tab panels. */
const PANEL_STYLE: React.CSSProperties = {
  flex: "auto",
  textAlign: "center",
  margin: "20px",
};

/** Tabs in display order; `key` is the tab's DOM id and the suffix of its panel id. */
const TABS = [
  { key: "metrics", label: "Metrics Table" },
  { key: "roc", label: "ROC Curve" },
  { key: "pr", label: "PR Curve" },
  { key: "confusion", label: "Confusion Matrix" },
] as const;

/** Formats a per-class score for a legend label; a missing class shows "N/A" instead of crashing. */
const formatScore = (score: number | undefined): string =>
  typeof score === "number" ? score.toFixed(3) : "N/A";

/** True for any non-null object (arrays included), i.e. a value whose entries can be read by key. */
const isRecord = (value: unknown): value is { [key: string]: unknown } =>
  typeof value === "object" && value !== null;

/**
 * Renders an arbitrary metrics-table cell as text. Objects are JSON-encoded
 * because React throws when a plain object is rendered as a child.
 */
const formatCell = (value: unknown): string => {
  if (value === null || value === undefined) {
    return "";
  }
  if (typeof value === "object") {
    return JSON.stringify(value);
  }
  return String(value);
};

/**
 * One ROC curve per class (x = false-positive rate, y = true-positive rate),
 * with each class's AUC in its legend entry.
 *
 * Declared at module scope (not inside the viewer) so its component identity
 * is stable across renders and React does not remount the Plotly chart.
 */
const ROCPlot: React.FC<ROCProps> = ({ rocData }) => {
  if (!rocData) {
    return <div>ROC Plots are not initialized!</div>;
  }
  // `thresholds` is not plotted, but it has always been required to draw.
  if (!rocData.tpr || !rocData.fpr || !rocData.thresholds || !rocData.auc) {
    return <div>ROC Plots are not initialized!</div>;
  }
  const { tpr, fpr, auc } = rocData;

  const data: Partial<Data>[] = Object.keys(tpr).map((key) => ({
    x: fpr[key],
    y: tpr[key],
    type: "scatter" as const,
    mode: "lines" as const,
    name: `${key} (AUC: ${formatScore(auc[key])})`,
    // No per-point `text` is supplied yet (threshold hover text is not wired up).
    hoverinfo: "x+y+text" as const,
  }));

  return (
    <Plot
      data={data}
      layout={{
        title: "ROC Curves",
        xaxis: { title: "False Positive Rate", range: [0, 1] },
        yaxis: { title: "True Positive Rate", range: [0, 1] },
        legend: { title: { text: "Class Labels with AUC" } },
      }}
      style={{ width: "100%", height: "100%" }}
    />
  );
};

/**
 * One precision-recall curve per class (x = recall, y = precision), with each
 * class's F1 score in its legend entry.
 */
const PRCurve: React.FC<PRPlotProps> = ({ prData }) => {
  if (!prData) {
    return <div>PR Plots are not initialized!</div>;
  }
  if (!prData.precision || !prData.recall || !prData.f1) {
    return <div>PR Plots are not initialized!</div>;
  }
  const { precision, recall, f1 } = prData;

  const data = Object.keys(precision).map((key) => ({
    x: recall[key],
    y: precision[key],
    type: "scatter" as const,
    mode: "lines+markers" as const,
    name: `${key} (F1: ${formatScore(f1[key])})`,
    hoverinfo: "x+y" as const,
    marker: { size: 8 },
  }));

  return (
    <Plot
      data={data}
      layout={{
        title: "Precision-Recall Curve",
        xaxis: { title: "Recall" },
        yaxis: { title: "Precision", range: [0, 1] },
        legend: { title: { text: "Classes with F1 Scores" } },
      }}
      style={{ width: "100%", height: "100%" }}
    />
  );
};

/**
 * Confusion matrix as a heatmap: rows are true labels, columns are predicted
 * labels, and `labels` names both axes.
 */
const ConfusionMatrix: React.FC<ConfusionMatrixPlotProps> = ({
  confMatrix,
}) => {
  if (!confMatrix) {
    return <div>Confusion Matrix is not initialized!</div>;
  }
  if (!confMatrix.values || !confMatrix.labels) {
    return <div>Confusion Matrix is not initialized!</div>;
  }
  return (
    <Plot
      data={[
        {
          z: confMatrix.values,
          x: confMatrix.labels,
          y: confMatrix.labels,
          type: "heatmap" as const,
          colorscale: "Viridis" as const,
          showscale: true, // show the colour-scale bar
          hoverinfo: "z+x+y" as const, // cell value plus row/column labels on hover
        },
      ]}
      layout={{
        title: "Confusion Matrix",
        xaxis: { title: "Predicted Label", tickangle: -45, automargin: true },
        yaxis: { title: "True Label", automargin: true },
        width: 700,
        height: 700,
        margin: { l: 100, r: 100, b: 100, t: 100 }, // room for long axis labels
      }}
      config={{
        responsive: true,
      }}
      style={{ width: "100%", height: "100%" }}
    />
  );
};

/**
 * Metrics table.
 *
 * - If the payload has a `"Model"` column (an object), renders one column per
 *   model and one row per metric column; cells are looked up with the
 *   `"Model"` column's keys, so every row lines up with the header.
 * - Otherwise renders a two-column Metric/Value table of the numeric entries.
 */
const MetricsTable: React.FC<MetricsTableProps> = ({ metrics }) => {
  if (!metrics) {
    return <div> Metrics are not available! </div>;
  }

  const entries = Object.entries(metrics);
  const modelColumn = entries.find(([key]) => key === MODEL_COLUMN)?.[1];

  if (isRecord(modelColumn)) {
    const modelKeys = Object.keys(modelColumn);
    return (
      <TableContainer component={Paper}>
        <Table sx={{ minWidth: 650 }} aria-label="simple table">
          <TableHead>
            <TableRow>
              <TableCell sx={{ fontWeight: "bold" }}>Metric</TableCell>
              {modelKeys.map((modelKey) => (
                <TableCell key={modelKey} sx={{ fontWeight: "bold" }}>
                  {formatCell(modelColumn[modelKey])}
                </TableCell>
              ))}
            </TableRow>
          </TableHead>
          <TableBody>
            {entries.map(([metricName, column]) => {
              if (NON_METRIC_COLUMNS.has(metricName) || !isRecord(column)) {
                return null;
              }
              return (
                <TableRow key={metricName} sx={LAST_ROW_SX}>
                  <TableCell component="th" scope="row">
                    {metricName}
                  </TableCell>
                  {modelKeys.map((modelKey) => (
                    <TableCell key={modelKey}>
                      {formatCell(column[modelKey])}
                    </TableCell>
                  ))}
                </TableRow>
              );
            })}
          </TableBody>
        </Table>
      </TableContainer>
    );
  }

  const numericMetrics = entries.filter(
    (entry): entry is [string, number] => typeof entry[1] === "number"
  );
  return (
    <TableContainer component={Paper}>
      <Table sx={{ minWidth: 650 }} aria-label="simple table">
        <TableHead>
          <TableRow>
            <TableCell sx={{ fontWeight: "bold" }}>Metric</TableCell>
            <TableCell sx={{ fontWeight: "bold" }}>Value</TableCell>
          </TableRow>
        </TableHead>
        <TableBody>
          {numericMetrics.map(([metricName, value]) => (
            <TableRow key={metricName} sx={LAST_ROW_SX}>
              <TableCell component="th" scope="row">
                {metricName}
              </TableCell>
              <TableCell>{value.toString()}</TableCell>
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </TableContainer>
  );
};

/**
 * Lists the response JSON files under `<pathPrefix>/responses/<serviceName>`
 * in `bucketName`, loads the selected file, and shows its metrics table, ROC
 * curves, PR curves and confusion matrix in tabs.
 *
 * Changing `refresh` (or the bucket, prefix or service) lists the files again.
 */
const ClassificationResponseViewer: React.FC<TrainResponseViewerProps> = ({
  refresh,
  bucketName,
  pathPrefix,
  serviceName,
}) => {
  const [files, setFiles] = useState<S3File[]>([]);
  const [selectedIndex, setSelectedIndex] = useState(0);
  // Starts empty; the selected file's JSON is fetched asynchronously below.
  const [apiResponse, setApiResponse] = useState<ApiResponse>({});
  const [isLoading, setIsLoading] = useState(false);
  const [activeTab, setActiveTab] = useState<number>(0);

  // S3 prefix of this service's response files (no trailing slash).
  const responsesPrefix = `${pathPrefix}/responses/${serviceName}`;

  const handleTabChange = (_event: React.SyntheticEvent, newValue: number) => {
    setActiveTab(newValue);
  };

  // (Re)list the response files. `cancelled` discards a listing superseded by
  // a newer one, so a slow old request cannot overwrite fresher results.
  useEffect(() => {
    let cancelled = false;
    const loadFiles = async () => {
      try {
        const fetchedFiles = await fetchFiles(bucketName, responsesPrefix, [
          "json",
        ]);
        if (!cancelled) {
          setFiles(fetchedFiles);
        }
      } catch (error) {
        console.error("Failed to fetch files:", error);
      }
    };
    void loadFiles();
    return () => {
      cancelled = true;
    };
  }, [refresh, bucketName, responsesPrefix]);

  // Load the selected file's JSON. Results from a superseded selection are
  // ignored so the user always sees the file they selected last.
  useEffect(() => {
    const selectedFile = files[selectedIndex];
    if (!selectedFile) {
      // No valid selection (e.g. the list shrank): don't show a previous file's data.
      setApiResponse({});
      setIsLoading(false);
      return undefined;
    }

    let cancelled = false;
    const fetchData = async () => {
      setIsLoading(true);
      try {
        const response = await readJsonFromS3(bucketName, selectedFile.Key);
        if (!cancelled) {
          setApiResponse(response as ApiResponse);
        }
      } catch (error) {
        console.error("Got an error: ", error);
        if (!cancelled) {
          setApiResponse({});
        }
      } finally {
        if (!cancelled) {
          setIsLoading(false);
        }
      }
    };
    void fetchData();
    return () => {
      cancelled = true;
    };
  }, [selectedIndex, files, bucketName]);

  const plots = apiResponse.visualization_plots;

  return (
    <div>
      <Grid container spacing={3}>
        <Grid item xs={12} sm={6}>
          <Paper style={{ maxHeight: 400, overflow: "auto" }}>
            <List
              component="nav"
              aria-label="finished jobs"
              title="List of Train Jobs"
            >
              {files.map((file, index) => (
                <ListItemButton
                  key={file.Key}
                  selected={selectedIndex === index}
                  onClick={() => setSelectedIndex(index)}
                >
                  <ListItemText
                    primary={file.Key.replace(`${responsesPrefix}/`, "")}
                  />
                </ListItemButton>
              ))}
            </List>
          </Paper>
        </Grid>
        <Grid item xs={12} sm={6}>
          {/* Intentionally empty: a job-details table (output location, log
              id, warnings, error) used to be here but was commented out;
              ApiResponse does not declare several of those fields. */}
        </Grid>
      </Grid>
      <Tabs
        variant="scrollable"
        scrollButtons="auto"
        value={activeTab}
        onChange={handleTabChange}
        aria-label="file tabs"
        sx={{
          backgroundColor: "white", // white background for the whole tab bar
          boxShadow: "0 2px 4px rgba(0,0,0,0.1)", // subtle shadow under the tab bar
          "& .MuiTabs-flexContainer": {
            gap: "10px", // space between tabs
          },
        }}
      >
        {TABS.map(({ key, label }) => (
          <Tab
            key={key}
            label={label}
            id={key}
            aria-controls={`tabpanel-${key}`}
          />
        ))}
      </Tabs>
      {isLoading ? (
        <div style={{ textAlign: "center", margin: "15px" }}>
          <CircularProgress />
        </div>
      ) : apiResponse.status === "requested" ? (
        <div>
          <Alert severity="warning">Job in progress</Alert>
        </div>
      ) : apiResponse.status === "failed" ? (
        <div>
          <Alert severity="error">Job failed!</Alert>
        </div>
      ) : (
        <div>
          {/* Panel content mounts only while its tab is active, so Plotly
              draws into a visible container and gets a real size. */}
          <div
            role="tabpanel"
            hidden={activeTab !== 0}
            id="tabpanel-metrics"
            aria-labelledby="metrics"
            style={PANEL_STYLE}
          >
            {activeTab === 0 && <MetricsTable metrics={apiResponse.metrics} />}
          </div>
          <div
            role="tabpanel"
            hidden={activeTab !== 1}
            id="tabpanel-roc"
            aria-labelledby="roc"
            style={PANEL_STYLE}
          >
            {activeTab === 1 && <ROCPlot rocData={plots?.auc} />}
          </div>
          <div
            role="tabpanel"
            hidden={activeTab !== 2}
            id="tabpanel-pr"
            aria-labelledby="pr"
            style={PANEL_STYLE}
          >
            {activeTab === 2 && <PRCurve prData={plots?.pr} />}
          </div>
          <div
            role="tabpanel"
            hidden={activeTab !== 3}
            id="tabpanel-confusion"
            aria-labelledby="confusion"
            style={PANEL_STYLE}
          >
            {activeTab === 3 && (
              <ConfusionMatrix confMatrix={plots?.confusion_matrix} />
            )}
          </div>
        </div>
      )}
    </div>
  );
};

export default ClassificationResponseViewer;
