import React, { ChangeEvent, useEffect, useState } from "react";
import "./App.css";
import {
  Backdrop,
  Box,
  Button,
  CircularProgress,
  Divider,
  Grid,
  IconButton,
  InputBase,
  Paper,
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableRow,
  TextField,
} from "@mui/material";
// import SendIcon from "@mui/icons-material/Send";
import SendSharpIcon from "@mui/icons-material/SendSharp";
import "@tensorflow/tfjs-backend-webgl";
import { questions } from "./answers";

import {
  BarChart,
  Bar,
  XAxis,
  YAxis,
  CartesianGrid,
  Tooltip,
  Legend,
} from "recharts";
import chatBot from "./lib/chatGPT-handler";
import { ChatHistory, Message } from "./components/ChatHistory";
//https://github.com/tensorflow/tfjs-models/tree/master/universal-sentence-encoder
const use = require("@tensorflow-models/universal-sentence-encoder");

/**
 * Cosine similarity of two equal-length vectors: dot(A, B) / (|A| * |B|).
 *
 * @param A - first vector
 * @param B - second vector; must have the same length as `A`
 * @returns a value in (approximately) [-1, 1]; returns 0 when either vector
 *   has zero magnitude, where the ratio is undefined and would otherwise be NaN
 * @throws RangeError if the vectors have different lengths
 */
const cosinesim = (A: readonly number[], B: readonly number[]): number => {
  if (A.length !== B.length) {
    throw new RangeError(
      `cosinesim: vector length mismatch (${A.length} vs ${B.length})`
    );
  }

  let dotProduct = 0;
  let sumSquaresA = 0;
  let sumSquaresB = 0;
  for (let i = 0; i < A.length; i++) {
    const a = A[i];
    const b = B[i];
    dotProduct += a * b;
    sumSquaresA += a * a;
    sumSquaresB += b * b;
  }

  const magnitudeProduct = Math.sqrt(sumSquaresA) * Math.sqrt(sumSquaresB);
  // A zero vector has no direction; report "no similarity" instead of NaN so
  // callers that take a max over scores are not poisoned.
  return magnitudeProduct === 0 ? 0 : dotProduct / magnitudeProduct;
};

/**
 * Cache of Universal Sentence Encoder embeddings for every accepted answer,
 * filled once after the model loads. Layout: [question][answer][vector component].
 * The binding itself is never reassigned; only its entries are written.
 * NOTE(review): confirm the loader writes entries in the same order as `questions`.
 */
const allCorrectAnswerEmbeddings: number[][][] = [];

/** Number of question categories (bandit arms). NOTE(review): presumably should equal questions.length. */
const CATEGORY_COUNT = 5;

/** Cosine similarity at or above which an answer is accepted without asking the chat bot. */
const SIMILARITY_THRESHOLD = 0.97;

/**
 * Minimal shape of the Universal Sentence Encoder model as used here:
 * `embed` resolves to a 2-D tensor with one row per input sentence.
 */
type SentenceEncoder = {
  embed: (
    input: string | string[]
  ) => Promise<{ arraySync: () => number[][]; dispose: () => void }>;
};

/** Embed one sentence and return its vector, releasing the tensor afterwards (tfjs tensors are not garbage-collected). */
const embedSentence = async (
  model: SentenceEncoder,
  sentence: string
): Promise<number[]> => {
  const tensor = await model.embed(sentence);
  try {
    return tensor.arraySync()[0];
  } finally {
    tensor.dispose();
  }
};

/** Embed every answer in parallel; Promise.all keeps results in input order. */
const embedAnswers = (
  model: SentenceEncoder,
  answers: readonly string[]
): Promise<number[][]> =>
  Promise.all(answers.map((answer) => embedSentence(model, answer)));

/**
 * Find the reference vector most similar to `target` (first one wins ties).
 * @returns its index and similarity, or index -1 when no reference yields a
 *   comparable (non-NaN) similarity, e.g. when `references` is empty.
 */
const findClosest = (
  references: number[][],
  target: number[]
): { index: number; similarity: number } => {
  let best = { index: -1, similarity: -Infinity };
  references.forEach((reference, index) => {
    const similarity = cosinesim(reference, target);
    if (similarity > best.similarity) {
      best = { index, similarity };
    }
  });
  return best;
};

/** Uniform random integer in [min, max]; both bounds are INCLUSIVE. */
function getRandomNumber(min: number, max: number) {
  const randomBuffer = new Uint32Array(1);
  window.crypto.getRandomValues(randomBuffer);
  const randomNumber = randomBuffer[0] / (0xffffffff + 1); // Normalize to [0, 1)
  return Math.floor(randomNumber * (max - min + 1)) + min;
}

/** Table of each category's running penalty and how often it was chosen. */
const DataTable = ({ scores, pulls }: { scores: number[]; pulls: number[] }) => (
  <Table size="small">
    <TableHead>
      <TableRow>
        <TableCell size="small">Category</TableCell>
        <TableCell size="small">Penalty</TableCell>
        <TableCell size="small">Counts</TableCell>
      </TableRow>
    </TableHead>
    <TableBody>
      {scores.map((score, index) => (
        <TableRow key={index + 1}>
          <TableCell size="small">{index + 1}</TableCell>
          <TableCell size="small">{score.toFixed(3)}</TableCell>
          <TableCell size="small">{pulls[index]}</TableCell>
        </TableRow>
      ))}
    </TableBody>
  </Table>
);

/**
 * Quiz page. Shows the image for the current category and scores the user's
 * free-text answer against that category's reference answers by
 * sentence-embedding cosine similarity. Weak matches are sent to the chat bot
 * for feedback. The next category is picked with an epsilon-greedy bandit that
 * favours the category with the highest running penalty.
 */
function App() {
  const [userInput, setUserInput] = useState<string>("");
  const [scores, setScores] = useState<number[]>(() =>
    new Array<number>(CATEGORY_COUNT).fill(0)
  );
  const [chosenCategory, setChosenCategory] = useState(() =>
    Math.floor(Math.random() * CATEGORY_COUNT)
  );
  // The initial category counts as its first pull, so running-mean updates never divide by 0.
  const [pulls, setPulls] = useState<number[]>(() =>
    Array.from({ length: CATEGORY_COUNT }, (_, i) => (i === chosenCategory ? 1 : 0))
  );
  // Exploration rate of the bandit; fixed for now (no setter is used).
  const [epsilon] = useState(0.8);
  const [probabilities, setProbabilities] = useState(() =>
    Array.from({ length: CATEGORY_COUNT }, (_, i) => ({
      name: String(i + 1),
      value: 100 / CATEGORY_COUNT,
    }))
  );
  const [useModel, setUseModel] = useState<SentenceEncoder | null>(null);
  const [answerScore, setAnswerScore] = useState<number | undefined>(undefined);
  const [openBackdrop, setOpenBackdrop] = useState(true);
  const [backdropMsg, setBackdropMsg] = useState("");
  const [chatHistory, setChatHistory] = useState<Message[]>([
    {
      sender: "bot",
      text: "Hi there! Enter your answer below and I'll tell you how well you did!",
    },
  ]);

  const appendMessage = (message: Message) =>
    setChatHistory((previous) => [...previous, message]);

  /** Score the typed answer and post feedback to the chat history. */
  const handleSubmission = async () => {
    // Enter bypasses the disabled send button, and the backdrop is up while the
    // model loads or a chat-bot reply is pending, so guard here too.
    if (!useModel || openBackdrop || userInput.trim().length === 0) {
      return;
    }
    const rawAnswer = userInput;
    const normalizedAnswer = rawAnswer.toLowerCase();
    const referenceAnswers = questions[chosenCategory].answer;
    appendMessage({ sender: "user", text: normalizedAnswer });
    setUserInput("");

    try {
      // Prefer embeddings cached at load time; embed on demand if not ready yet.
      const referenceEmbeddings =
        allCorrectAnswerEmbeddings[chosenCategory] ??
        (await embedAnswers(useModel, referenceAnswers));
      const userEmbedding = await embedSentence(useModel, normalizedAnswer);
      const closest = findClosest(referenceEmbeddings, userEmbedding);
      if (closest.index === -1) {
        appendMessage({
          sender: "bot",
          text: "I don't have a reference answer for this question yet.",
        });
        return;
      }
      console.log(
        `closest answer is "${referenceAnswers[closest.index]}" with similarity score: ${closest.similarity}`
      );
      setAnswerScore(closest.similarity);

      if (closest.similarity >= SIMILARITY_THRESHOLD) {
        appendMessage({ sender: "bot", text: "Well done, great answer!" });
        return;
      }

      setBackdropMsg("Checking with the experts...");
      setOpenBackdrop(true);
      const response = await chatBot(rawAnswer, referenceAnswers[closest.index]);
      appendMessage({ sender: "bot", text: response });
    } catch (error: unknown) {
      console.error("Failed to check the answer", error);
      appendMessage({
        sender: "bot",
        text: "Sorry, something went wrong while checking your answer. Please try again.",
      });
    } finally {
      setOpenBackdrop(false);
    }
  };

  const handleInputChange = (
    e: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>
  ) => {
    setUserInput(e.target.value);
  };

  /**
   * Pick the next category with epsilon-greedy selection over `scores`
   * (higher penalty = exploit target), update the displayed probabilities
   * and pull counts, and return the chosen index.
   */
  const epsilonGreedy = (scores: number[]): number => {
    const armCount = scores.length;
    let bestArm = 0;
    let bestScore = 0;
    scores.forEach((score, index) => {
      if (score > bestScore) {
        bestScore = score;
        bestArm = index;
      }
    });

    // With no feedback yet, every arm is equally likely and we always explore.
    const untried = scores.every((score) => score === 0);
    const exploreShare = untried ? 100 / armCount : (epsilon * 100) / armCount;
    const exploitShare = untried
      ? exploreShare
      : exploreShare + 100 - epsilon * 100;
    const roll = untried ? 0 : Math.random();

    setProbabilities(
      scores.map((_, index) => ({
        name: String(index + 1),
        value: index === bestArm ? exploitShare : exploreShare,
      }))
    );

    const explore = roll < epsilon;
    const choice = explore ? getRandomNumber(0, armCount - 1) : bestArm;
    console.log("choice for next round", explore ? "explore" : "exploit", ":", choice);
    setPulls((previous) => {
      const next = [...previous];
      next[choice] += 1;
      return next;
    });
    return choice;
  };

  /** Move to the next category chosen by the bandit and reset per-question UI. */
  const pullArm = (updatedScores: number[]) => {
    setChosenCategory(epsilonGreedy(updatedScores));
    setAnswerScore(undefined);
    setUserInput("");
    setChatHistory([]);
  };

  /**
   * Fold a penalty in [0, 1] (1 = wrong) into the current category's running
   * mean, then choose the next category.
   */
  const recordOutcome = (penalty: number) => {
    // Negative cosine similarity would otherwise push the penalty above 1.
    const clampedPenalty = Math.min(1, Math.max(0, penalty));
    const count = Math.max(1, pulls[chosenCategory]);
    const newScores = [...scores];
    newScores[chosenCategory] +=
      (clampedPenalty - newScores[chosenCategory]) / count;
    setScores(newScores);
    pullArm(newScores);
  };

  const handleWrong = () => recordOutcome(1);

  const handleRight = () => recordOutcome(0);

  // Without a scored answer, assume it was wrong.
  const handleNext = () =>
    recordOutcome(answerScore !== undefined ? 1 - answerScore : 1);

  useEffect(() => {
    // Ask for confirmation before leaving: quiz progress only lives in component state.
    const handleBeforeUnload = (e: BeforeUnloadEvent) => {
      e.preventDefault();
      e.returnValue = ""; // For Chrome compatibility
    };
    window.addEventListener("beforeunload", handleBeforeUnload);
    return () => window.removeEventListener("beforeunload", handleBeforeUnload);
  }, []);

  useEffect(() => {
    // Guards against state updates after unmount (and the StrictMode double run).
    let cancelled = false;
    setBackdropMsg("Loading model...");

    const loadModel = async () => {
      const model: SentenceEncoder = await use.load({
        vocabUrl: "/vocab/bert_vocab.json",
      });
      if (cancelled) return;
      setUseModel(model);

      setBackdropMsg("Loading questions...");
      // Embed all categories in parallel, but write by index so that
      // allCorrectAnswerEmbeddings[i][j] matches questions[i].answer[j] and a
      // repeated run overwrites instead of appending duplicates.
      await Promise.all(
        questions.map(async (question, questionIndex) => {
          allCorrectAnswerEmbeddings[questionIndex] = await embedAnswers(
            model,
            question.answer
          );
        })
      );
      if (!cancelled) setOpenBackdrop(false);
    };

    loadModel().catch((error: unknown) => {
      console.error("Failed to load the sentence encoder", error);
      if (!cancelled) {
        setBackdropMsg("Failed to load the model. Please reload the page.");
      }
    });

    return () => {
      cancelled = true;
    };
  }, []);

  return (
    <div className="App">
      <Box
        display="flex"
        overflow={"auto"}
        alignItems="center"
        height="100vh" // Take up the entire viewport height
        margin={0}
        paddingRight={0}
        boxSizing="border-box" // Ensure padding is included in the total width/height
        style={{ display: "flex", flexDirection: "column" }}
        bgcolor={"whitesmoke"}
      >
        <Box
          style={{ display: "flex", flexDirection: "column" }}
          justifyContent="center"
          alignItems="center"
          padding={2}
        >
          <img
            src={questions[chosenCategory].imgPath}
            alt={`Question for category ${chosenCategory + 1}`}
            width={300}
            height={300}
          />
          Category {chosenCategory + 1} (
          {answerScore !== undefined ? answerScore.toFixed(2) : "---"})
          <ChatHistory messages={chatHistory} />
          <Paper
            component="form"
            sx={{
              p: "2px 4px",
              display: "flex",
              alignItems: "center",
              width: "100%",
            }}
          >
            <InputBase
              sx={{ ml: 1, flex: 1 }}
              placeholder="Your answer here..."
              inputProps={{ "aria-label": "Your answer here..." }}
              multiline
              value={userInput}
              onChange={handleInputChange}
              // Enter submits; Shift+Enter inserts a newline.
              onKeyDown={(ev) => {
                if (ev.key === "Enter" && !ev.shiftKey) {
                  ev.preventDefault();
                  void handleSubmission();
                }
              }}
            />
            <Divider sx={{ height: 28, m: 0.5 }} orientation="vertical" />
            <IconButton
              color="primary"
              sx={{ p: "10px" }}
              aria-label="send answer"
              onClick={() => void handleSubmission()}
              disabled={userInput.trim().length === 0}
            >
              <SendSharpIcon />
            </IconButton>
          </Paper>
          <Box
            display={"flex"}
            flexDirection={"row"}
            gap={2}
            margin={2}
            justifyContent="center"
          >
            <Button variant="contained" color="secondary" onClick={handleWrong}>
              wrong
            </Button>
            <Button variant="contained" color="success" onClick={handleRight}>
              right
            </Button>
            <Button
              variant="contained"
              size="large"
              color="primary"
              onClick={handleNext}
              disabled={answerScore === undefined}
            >
              next
            </Button>
          </Box>
        </Box>
        <Box
          display={"flex"}
          flexDirection={"row"}
          gap={2}
          margin={0}
          justifyContent="center"
          alignContent={"center"}
        >
          <Grid container spacing={2}>
            <Grid item xs={12} sm={6}>
              <Box paddingX={5}>
                <DataTable scores={scores} pulls={pulls} />
              </Box>
            </Grid>
            <Grid item xs={12} sm={6}>
              <Box
                display={"flex"}
                flexDirection={"column"}
                gap={2}
                paddingX={5}
                margin={0}
                justifyContent="center"
                alignContent={"center"}
              >
                <strong>Question Category Probabilities (current) </strong>
                <BarChart width={300} height={200} data={probabilities}>
                  <CartesianGrid strokeDasharray="3 3" />
                  <XAxis dataKey="name" />
                  <YAxis tickFormatter={(tick) => `${tick}%`} />
                  <Tooltip />
                  <Bar dataKey="value" fill="#8884d8" />
                </BarChart>
              </Box>
            </Grid>
          </Grid>
        </Box>
      </Box>
      <Backdrop
        sx={{
          color: "#fff",
          backgroundColor: "rgba(0, 0, 0, 0.8)",
          zIndex: (theme) => theme.zIndex.drawer + 1,
        }}
        open={openBackdrop}
      >
        <Box display="flex" alignItems="center" gap={2}>
          <CircularProgress color="inherit" />
          {backdropMsg}
        </Box>
      </Backdrop>
    </div>
  );
}

export default App;
