import npyjs from "npyjs";
import { InferenceSession, Tensor } from "onnxruntime-web";
import React, { Component } from "react";
import { Spinner } from "react-bootstrap";
import { IoCaretBackCircle, IoPlayBackCircle } from "react-icons/io5";
import { MdDelete } from "react-icons/md";
import { getImageEmbeddings, readVariation, updateVariation } from "../api";
import { colors } from "../colors";
import { fontSizes } from "../fontSizes";
import { fontWeights } from "../fontWeights";
import Button from "./Button";

const ort = require("onnxruntime-web");
/**
 * Paint an ONNX mask prediction into an RGBA pixel buffer.
 *
 * Every element of `input` greater than 0.0 is painted opaque blue
 * (rgb 0, 114, 189). Every other pixel stays fully transparent black. Using
 * 0.0 as the cut-off corresponds to `predictor.model.mask_threshold` in the
 * Python SAM tooling.
 *
 * Orientation note: the ImageData is built as `new ImageData(data, height, width)`,
 * so `height` becomes the pixel width of the result. `onnxMaskToImage` is called
 * with the decoder output's `dims[2], dims[3]` (presumably rows, columns), so
 * the two parameters arrive swapped and the final image is oriented correctly.
 *
 * @param {ArrayLike<number>} input - Flat mask logits.
 * @param {number} width - See the orientation note (number of rows at the call site).
 * @param {number} height - See the orientation note (number of columns at the call site).
 * @returns {ImageData}
 */
function arrayToImageData(input, width, height) {
  const maskColor = [0, 114, 189, 255];
  // A fresh Uint8ClampedArray is already zero-filled (transparent black).
  const pixels = new Uint8ClampedArray(4 * width * height);

  for (let idx = 0; idx < input.length; idx += 1) {
    if (input[idx] > 0.0) {
      const base = 4 * idx;
      maskColor.forEach((channel, offset) => {
        pixels[base + offset] = channel;
      });
    }
  }

  return new ImageData(pixels, height, width);
}

// Use a Canvas element to produce an image from ImageData
/**
 * Turn ImageData into an <img> element.
 *
 * The pixels are drawn onto a temporary canvas, and the canvas's PNG data URL
 * becomes the image source.
 *
 * @param {ImageData} imageData
 * @returns {HTMLImageElement}
 */
function imageDataToImage(imageData) {
  const dataUrl = imageDataToCanvas(imageData).toDataURL();
  return Object.assign(new Image(), { src: dataUrl });
}

/**
 * Merge mask overlays into a single opaque black-and-white mask.
 *
 * A pixel is white (255, 255, 255, 255) when at least one mask has alpha > 0
 * there. Otherwise it is black (0, 0, 0, 255). An empty `masks` list yields an
 * all-black image.
 *
 * @param {CanvasImageSource[]} masks - Mask images. Each one is drawn stretched
 *   to `width` x `height`.
 * @param {number} width - Output width in pixels.
 * @param {number} height - Output height in pixels.
 * @returns {string} PNG data URL of the combined mask.
 */
function combineMasks(masks, width, height) {
  const canvas = document.createElement("canvas");
  const ctx = canvas.getContext("2d");
  canvas.width = width;
  canvas.height = height;

  // A fresh canvas uses source-over compositing, which never lowers a pixel's
  // alpha. After every mask has been drawn onto this one transparent canvas, a
  // pixel has alpha > 0 exactly when some mask covers it. That avoids
  // allocating a full-size canvas and doing a readback for each mask.
  for (const mask of masks) {
    ctx.drawImage(mask, 0, 0, width, height);
  }

  const imageData = ctx.getImageData(0, 0, width, height);
  const { data } = imageData;
  for (let i = 0; i < width * height; i++) {
    const value = data[4 * i + 3] > 0 ? 255 : 0; // white where covered, else black
    data[4 * i + 0] = value;
    data[4 * i + 1] = value;
    data[4 * i + 2] = value;
    data[4 * i + 3] = 255; // the output is always fully opaque
  }

  ctx.putImageData(imageData, 0, 0);
  return canvas.toDataURL();
}
// Canvas elements can be created from ImageData
/**
 * Create a canvas sized to `imageData` and paint the pixels onto it.
 *
 * If no 2D context is available, the canvas is returned sized but blank.
 *
 * @param {ImageData} imageData
 * @returns {HTMLCanvasElement}
 */
function imageDataToCanvas(imageData) {
  const canvas = document.createElement("canvas");
  const context = canvas.getContext("2d");
  const { width, height } = imageData;
  canvas.width = width;
  canvas.height = height;
  if (context != null) {
    context.putImageData(imageData, 0, 0);
  }
  return canvas;
}

// Convert the onnx model mask output to an HTMLImageElement
/**
 * Render a decoder mask output as an <img> element painted in the mask blue.
 *
 * @param {ArrayLike<number>} input - Flat mask logits from the ONNX output.
 * @param {number} width - Forwarded to `arrayToImageData` (callers pass `dims[2]`).
 * @param {number} height - Forwarded to `arrayToImageData` (callers pass `dims[3]`).
 * @returns {HTMLImageElement}
 */
export function onnxMaskToImage(input, width, height) {
  const imageData = arrayToImageData(input, width, height);
  return imageDataToImage(imageData);
}

/**
 * Build the input feed for the SAM ONNX mask decoder from the user's clicks.
 *
 * @param {object} params
 * @param {Array<{x: number, y: number, clickType?: number}>} params.clicks -
 *   Click positions in rendered-image pixels. `clickType` is the point label
 *   passed to the model; when it is omitted it defaults to 1, conventionally
 *   "foreground" in the SAM ONNX example.
 * @param {Tensor} params.tensor - Image embeddings, forwarded as `image_embeddings`.
 * @param {{samScale: number, originalHeight: number, originalWidth: number}} params.modelScale -
 *   Output of `handleImageScale`.
 * @returns {object|undefined} Feed object for `InferenceSession.run`, or
 *   `undefined` when `clicks` is missing.
 */
const modelData = ({ clicks, tensor, modelScale }) => {
  // Without click prompts there is nothing to decode.
  if (!clicks) return undefined;

  const n = clicks.length;

  // No box prompt is used, so a single padding point with label -1 at
  // (0.0, 0.0) is appended after the clicks: allocate room for n + 1 points.
  const pointCoords = new Float32Array(2 * (n + 1));
  const pointLabels = new Float32Array(n + 1);

  for (let i = 0; i < n; i++) {
    const { x, y, clickType } = clicks[i];
    // Scale rendered-pixel coordinates into the frame SAM expects.
    pointCoords[2 * i] = x * modelScale.samScale;
    pointCoords[2 * i + 1] = y * modelScale.samScale;
    // A Float32Array stores `undefined` as NaN, which matches no label, so
    // unlabeled clicks default to 1 (foreground). `??` keeps an explicit 0.
    pointLabels[i] = clickType ?? 1;
  }

  // The padding point: (0, 0) with label -1.
  pointCoords[2 * n] = 0.0;
  pointCoords[2 * n + 1] = 0.0;
  pointLabels[n] = -1.0;

  const pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]);
  const pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]);
  const imageSizeTensor = new Tensor("float32", [
    modelScale.originalHeight,
    modelScale.originalWidth,
  ]);

  // There is no previous mask, so pass an all-zero mask and flag it as absent.
  const maskInput = new Tensor(
    "float32",
    new Float32Array(256 * 256),
    [1, 1, 256, 256]
  );
  const hasMaskInput = new Tensor("float32", [0]);

  return {
    image_embeddings: tensor,
    point_coords: pointCoordsTensor,
    point_labels: pointLabelsTensor,
    orig_im_size: imageSizeTensor,
    mask_input: maskInput,
    has_mask_input: hasMaskInput,
  };
};

/**
 * Measure how an `objectFit: "contain"` image is actually drawn, and derive the
 * factor that maps rendered-pixel coordinates onto SAM's 1024-pixel long side.
 *
 * @param {HTMLImageElement} image - `width`/`height` are read as the element box
 *   and `naturalWidth`/`naturalHeight` as the intrinsic picture size.
 * @returns {{height: number, width: number, samScale: number,
 *   originalHeight: number, originalWidth: number}} The rendered picture size,
 *   the SAM scale factor and the intrinsic size.
 */
const handleImageScale = (image) => {
  const LONG_SIDE_LENGTH = 1024;
  const {
    width: boxWidth,
    height: boxHeight,
    naturalWidth,
    naturalHeight,
  } = image;

  const pictureRatio = naturalWidth / naturalHeight;
  const boxRatio = boxWidth / boxHeight;

  // A picture wider than its box fills the width (bars above and below).
  // Otherwise it fills the height (bars left and right).
  const fillsWidth = pictureRatio > boxRatio;
  const renderedWidth = fillsWidth ? boxWidth : boxHeight * pictureRatio;
  const renderedHeight = fillsWidth ? boxWidth / pictureRatio : boxHeight;

  return {
    height: renderedHeight,
    width: renderedWidth,
    samScale: LONG_SIDE_LENGTH / Math.max(renderedHeight, renderedWidth),
    originalHeight: naturalHeight,
    originalWidth: naturalWidth,
  };
};
class ZoneVariation extends Component {
  constructor(props) {
    super(props);
    this.state = { loadingEmbeddings: true, embeddings: null, masks: [] };
  }

  initModel = async () => {
    try {
      const model = await InferenceSession.create(
        process.env.PUBLIC_URL + "/models/export.onnx"
      );
      this.setState({ model: model });
    } catch (e) {
      console.log(e);
    }
  };
  loadFromArray = (array, dType) => {
    const tensor = new ort.Tensor(dType, array.data, array.shape);
    return tensor;
  };
  loadNpyTensor = async (tensorFile, dType) => {
    let npLoader = new npyjs();
    const npArray = await npLoader.load(tensorFile);
    const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
    return tensor;
  };

  handleImageClick = async (event) => {
    // Get the image's bounding rectangle
    const rect = this.image.getBoundingClientRect();

    // Calculate the click position relative to the div
    const x = event.clientX - rect.left;
    const y = event.clientY - rect.top;

    // Calculate the aspect ratios of the image and the div
    const imageAspectRatio = this.image.naturalWidth / this.image.naturalHeight;
    const divAspectRatio = rect.width / rect.height;

    let imageX, imageY;

    // Check if the image is letterboxed or pillarboxed
    if (imageAspectRatio > divAspectRatio) {
      // Image is letterboxed. It fills the width of the div and is centered vertically.
      const scaleFactor = rect.width / this.image.naturalWidth;
      const imageHeight = this.image.naturalHeight * scaleFactor;
      const verticalOffset = (rect.height - imageHeight) / 2;

      // Calculate the click position relative to the image
      imageX = x;
      imageY = y - verticalOffset;
    } else {
      // Image is pillarboxed. It fills the height of the div and is centered horizontally.
      const scaleFactor = rect.height / this.image.naturalHeight;
      const imageWidth = this.image.naturalWidth * scaleFactor;
      const horizontalOffset = (rect.width - imageWidth) / 2;

      // Calculate the click position relative to the image
      imageX = x - horizontalOffset;
      imageY = y;
    }

    // Check if the click is within the image boundaries
    if (
      imageX < 0 ||
      imageY < 0 ||
      imageX > this.image.naturalWidth ||
      imageY > this.image.naturalHeight
    ) {
      console.log("Click is outside the image");
      return;
    }

    console.log(imageX, imageY);
    const clicks = [];

    clicks.push({ x: imageX, y: imageY });
    this.setState({ clicks: clicks });

    const inputs = modelData({
      clicks: clicks,
      tensor: this.state.tensor,
      modelScale: this.state.modelScale,
    });
    console.log("click");
    console.log(
      imageX * this.state.modelScale.samScale,
      imageY * this.state.modelScale.samScale
    );
    const results = await this.state.model.run(inputs);
    const output = results[this.state.model.outputNames[0]];

    const image = onnxMaskToImage(output.data, output.dims[2], output.dims[3]);
    const masks = this.state.masks;
    masks.push(image);
    this.setState({ masks: masks });
  };

  loadParameters = async () => {
    try {
      const variation = await readVariation(this.props.variationId);
      this.setState({ variation });
    } catch (error) {
      console.error("Error:", error);
    }
  };
  async componentDidMount() {
    console.log("mounted");
    this.loadParameters();
    this.loadEmbeddings().then((embeddings) => {
      console.log(embeddings);
      const reshapedEmbeddings = new Float32Array(
        embeddings.embeddings.flat(3)
      );

      const tensor = new ort.Tensor(
        "float32",
        reshapedEmbeddings,
        [1, 256, 64, 64]
      );

      console.log(tensor);

      this.setState({ tensor });
    });

    this.initModel();
  }

  updateVariation = (key, value) => {
    try {
      const variation = this.state.variation;
      variation[key] = value;

      updateVariation(this.props.variationId, variation);
    } catch (error) {
      console.log(error);
    }
  };
  loadEmbeddings = async () => {
    try {
      this.setState({ loadingEmbeddings: true });
      const embeddings = await getImageEmbeddings(this.props.imageFileId);
      this.setState({ loadingEmbeddings: false });
      return embeddings;
    } catch (e) {}
  };

  async launch() {
    const finalMask = combineMasks(
      this.state.masks,
      this.state.modelScale.originalWidth,
      this.state.modelScale.originalHeight
    );

    await this.props.launchComplexVariation(finalMask);
  }

  render() {
    return (
      <div
        style={{
          display: "flex",
          flexDirection: "row",

          height: "100%",
          width: "100%",
          position: "relative",
        }}
      >
        <div
          onClick={() => this.props.deleteVariation()}
          style={{
            cursor: "pointer",
            width: 20,
            height: 20,
            position: "absolute",
            right: 0,
            top: -20,
          }}
        >
          <MdDelete size={20} color={colors.brown}></MdDelete>
        </div>
        <div
          className="p-2"
          style={{
            display: "flex",
            flexDirection: "row",

            height: "100%",
            width: "70%",
            borderRightWidth: 1,
            borderRightColor: colors.white,
            borderRightStyle: "solid",
            justifyContent: "center",
            alignItems: "center",
          }}
        >
          {this.state.loadingEmbeddings ? (
            <div
              style={{
                display: "flex",
                justifyContent: "center",
                alignItems: "center",
                position: "relative",
              }}
            >
              <Spinner variant="light"></Spinner>
              <div style={{ position: "absolute", bottom: -20 }}>
                <span
                  style={{
                    fontAlign: "left",
                    fontFamily: "Montserrat",
                    fontWeight: fontWeights.bold,
                    color: colors.brown,
                    fontSize: fontSizes.small,
                  }}
                >
                  Traitement...
                </span>
              </div>
            </div>
          ) : (
            <div
              onClick={this.handleImageClick}
              style={{ width: "100%", height: "100%", position: "relative" }}
            >
              <img
                ref={(ref) => (this.image = ref)}
                src={this.props.url}
                style={{
                  width: "100%",
                  height: "100%",
                  objectFit: "contain",
                  position: "relative",
                }}
                onLoad={(event) => {
                  const samScale = handleImageScale(event.target);
                  this.setState({
                    modelScale: samScale,
                  });
                }}
                onResize={(event) => {
                  console.log("resize");
                  const samScale = handleImageScale(event.target);
                  this.setState({
                    modelScale: samScale,
                  });
                }}
              ></img>
              {this.state.masks.map((mask) => (
                <img
                  style={{
                    opacity: 0.5,
                    width: "100%",
                    top: 0,
                    left: 0,
                    height: "100%",
                    objectFit: "contain",
                    position: "absolute",
                  }}
                  src={mask.src}
                />
              ))}
            </div>
          )}
        </div>

        <div
          className="p-2"
          style={{
            display: "flex",
            flexDirection: "column",
            justifyContent: "space-between",
            height: "100%",
            width: "30%",
          }}
        >
          <div
            style={{
              display: "flex",
              flexDirection: "column",
              width: "100%",
            }}
          >
            <span
              className="mb-1"
              style={{
                fontAlign: "left",
                fontFamily: "Montserrat",
                fontWeight: fontWeights.bold,
                color: colors.lightBlue,
                fontSize: fontSizes.normal,
                textAlign: "left",
              }}
            >
              Contrôles
            </span>
            <div
              className="mb-2"
              style={{
                width: "100%",
                display: "flex",
                flexDirection: "row",
                justifyContent: "space-around",
              }}
            >
              <IoCaretBackCircle
                style={{ cursor: "pointer" }}
                onClick={() => {
                  let masks = [...this.state.masks];
                  masks.pop();
                  this.setState({ masks: masks });
                }}
                color={colors.brown}
                size={30}
              ></IoCaretBackCircle>
              <IoPlayBackCircle
                onClick={() => {
                  this.setState({ masks: [] });
                }}
                style={{ cursor: "pointer" }}
                color={colors.brown}
                size={30}
              ></IoPlayBackCircle>
            </div>
            <span
              className="mb-1"
              style={{
                fontAlign: "left",
                fontFamily: "Montserrat",
                fontWeight: fontWeights.bold,
                color: colors.lightBlue,
                fontSize: fontSizes.normal,
                textAlign: "left",
              }}
            >
              Modification
            </span>
            <div
              className="p-2 mt-1"
              style={{
                display: "flex",
                width: "100%",
                borderRadius: 10,
                borderColor: colors.lightBlue,
                borderWidth: 1,
                borderStyle: "solid",
                overflowWrap: "break-word",
              }}
            >
              <textarea
                onBlur={this.updateVariation.bind(
                  this,
                  "mj_zone_prompt",
                  this.state.variation?.mj_zone_prompt
                )}
                style={{
                  width: "100%",
                  backgroundColor: colors.darkBlue,
                  borderWidth: 0,
                  color: colors.white,
                  fontFamily: "Montserrat",
                  fontSize: fontSizes.small,
                  whiteSpace: "normal",
                  wordWrap: "break-word",
                }}
                defaultValue={this.state.variation?.mj_zone_prompt}
                value={this.state.variation?.mj_zone_prompt}
                onChange={(event) =>
                  this.setState((prevState) => ({
                    ...prevState,
                    variation: {
                      ...prevState.variation,
                      mj_zone_prompt: event.target.value,
                    },
                  }))
                }
              />
            </div>
          </div>
          <Button
            enabled={this.state.masks.length > 0}
            title={"Lancer"}
            width={"100%"}
            onClick={() => this.launch()}
          ></Button>
        </div>
      </div>
    );
  }
}

// The zone-variation editor component is this module's default export.
export default ZoneVariation;
