import React, { Component } from 'react';
import { colors } from '../colors';
import { InferenceSession, Tensor } from 'onnxruntime-web';
import npyjs from 'npyjs';
import { FaCommentSlash, FaMicrosoft } from 'react-icons/fa';

/**
 * Create an ONNX Runtime inference session for `models/export.onnx`.
 *
 * Failures are logged with console.log and not rethrown, so the returned
 * promise resolves to `undefined` when the session cannot be created.
 *
 * @returns {Promise<InferenceSession|undefined>} the session, or undefined on error.
 */
export async function initModel() {
  const modelUrl = 'models/export.onnx';
  let session;
  try {
    session = await InferenceSession.create(modelUrl);
  } catch (err) {
    console.log(err);
  }
  return session;
}

const ort = require('onnxruntime-web');
/**
 * Convert a flat mask prediction into RGBA pixels.
 *
 * Every value strictly greater than `threshold` is painted with `color`.
 * All other pixels stay transparent black (typed arrays are zero-initialised).
 *
 * NOTE: the result is built as `new ImageData(arr, height, width)`. The
 * `height` argument therefore becomes the ImageData width, and `width`
 * becomes its height. The only visible caller passes `output.dims[2]` and
 * then `output.dims[3]` in that order. Presumably those are the mask's
 * (rows, cols); confirm that before "fixing" the order, because swapping
 * it here alone would transpose non-square masks.
 *
 * @param {ArrayLike<number>} input - flat mask values, one per pixel.
 * @param {number} width - see NOTE; used as the ImageData height.
 * @param {number} height - see NOTE; used as the ImageData width.
 * @param {number[]} [color=[0, 114, 189, 255]] - RGBA for masked pixels (default: the mask's blue).
 * @param {number} [threshold=0.0] - values strictly above this are treated as masked.
 * @returns {ImageData} pixel buffer of `4 * width * height` bytes.
 */
function arrayToImageData(
  input,
  width,
  height,
  color = [0, 114, 189, 255],
  threshold = 0.0
) {
  const [r, g, b, a] = color;
  // Uint8ClampedArray is zero-filled on creation; no explicit fill needed.
  const arr = new Uint8ClampedArray(4 * width * height);
  for (let i = 0; i < input.length; i++) {
    // With the default threshold of 0.0 this matches thresholding the mask with
    // predictor.model.mask_threshold in python.
    if (input[i] > threshold) {
      // Individual index writes (rather than arr.set) silently ignore
      // out-of-range indices, as the original loop did.
      arr[4 * i + 0] = r;
      arr[4 * i + 1] = g;
      arr[4 * i + 2] = b;
      arr[4 * i + 3] = a;
    }
  }
  return new ImageData(arr, height, width);
}

/**
 * Render ImageData into an HTMLImageElement.
 *
 * The pixels are first painted onto a temporary canvas. That canvas is then
 * serialised with `toDataURL()` and used as the new image's `src`.
 *
 * @param {ImageData} imageData - pixels to render.
 * @returns {HTMLImageElement} an image whose src is the canvas data URL.
 */
function imageDataToImage(imageData) {
  const dataUrl = imageDataToCanvas(imageData).toDataURL();
  return Object.assign(new Image(), { src: dataUrl });
}

/**
 * Create a <canvas> sized to `imageData` and paint the pixels at (0, 0).
 *
 * If `getContext('2d')` yields no context, the canvas is still sized and
 * returned, just without the pixels drawn.
 *
 * @param {ImageData} imageData - pixels to draw.
 * @returns {HTMLCanvasElement} the new canvas.
 */
function imageDataToCanvas(imageData) {
  const canvasEl = document.createElement('canvas');
  const context2d = canvasEl.getContext('2d');
  Object.assign(canvasEl, { width: imageData.width, height: imageData.height });
  if (context2d != null) {
    context2d.putImageData(imageData, 0, 0);
  }
  return canvasEl;
}

/**
 * Convert the ONNX model's mask output into an HTMLImageElement overlay.
 *
 * `width` and `height` are forwarded unchanged to arrayToImageData. That
 * function passes `height` as the ImageData width and `width` as its height,
 * so pass the dimensions in the order the existing caller does.
 *
 * @param {ArrayLike<number>} input - flat mask values from the model output.
 * @param {number} width - forwarded as-is (becomes the image height).
 * @param {number} height - forwarded as-is (becomes the image width).
 * @returns {HTMLImageElement} the rendered mask.
 */
export function onnxMaskToImage(input, width, height) {
  const maskPixels = arrayToImageData(input, width, height);
  return imageDataToImage(maskPixels);
}

/**
 * Build the input feed for the SAM ONNX decoder from click prompts.
 *
 * Click coordinates are multiplied by `modelScale.samScale`. One padding point
 * at (0, 0) with label -1 is appended, since no box prompt is supplied. The
 * previous-mask inputs are zero-filled, and `has_mask_input` is 0.
 *
 * @param {Object} args
 * @param {Array<{x: number, y: number, clickType: number}>} args.clicks - click
 *   prompts. A falsy value means "no prompt" and yields undefined.
 * @param {*} args.tensor - image embedding, passed through as `image_embeddings`.
 * @param {{samScale: number, originalHeight: number, originalWidth: number}} args.modelScale
 *   scale factor and original image size (only read when `clicks` is truthy).
 * @returns {Object|undefined} the feed keyed by decoder input name, or undefined
 *   when there are no clicks.
 */
const modelData = ({ clicks, tensor, modelScale }) => {
  // Guard first: with no click prompts there is nothing to decode. Bailing out
  // here also avoids dereferencing a possibly-null modelScale.
  if (!clicks) return undefined;

  const n = clicks.length;

  // If there is no box input, a single padding point with label -1 and
  // coordinates (0.0, 0.0) must be concatenated, so size for (n + 1) points.
  const pointCoords = new Float32Array(2 * (n + 1));
  const pointLabels = new Float32Array(n + 1);

  // Add clicks, scaled to the coordinate space SAM expects.
  for (let i = 0; i < n; i++) {
    pointCoords[2 * i] = clicks[i].x * modelScale.samScale;
    pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale;
    pointLabels[i] = clicks[i].clickType;
  }

  // Padding point: (0, 0) with label -1.
  pointCoords[2 * n] = 0.0;
  pointCoords[2 * n + 1] = 0.0;
  pointLabels[n] = -1.0;

  return {
    image_embeddings: tensor,
    point_coords: new Tensor('float32', pointCoords, [1, n + 1, 2]),
    point_labels: new Tensor('float32', pointLabels, [1, n + 1]),
    orig_im_size: new Tensor('float32', [
      modelScale.originalHeight,
      modelScale.originalWidth,
    ]),
    // There is no previous mask, so default to an empty 256x256 mask...
    mask_input: new Tensor(
      'float32',
      new Float32Array(256 * 256),
      [1, 1, 256, 256]
    ),
    // ...and flag it as absent.
    has_mask_input: new Tensor('float32', [0]),
  };
};

const imgTest = require('../assets/truck.jpg');
/**
 * Compute the scale factor that maps this image's width/height onto SAM's
 * 1024-pixel longest side.
 *
 * @param {{width: number, height: number, naturalWidth: number, naturalHeight: number}} image
 *   an image-like element (here the rendered <img> from its onLoad event).
 * @returns {{height: number, width: number, samScale: number,
 *   originalHeight: number, originalWidth: number}} the measured width/height,
 *   the scale factor, and the natural (intrinsic) size.
 */
const handleImageScale = (image) => {
  const longestSideTarget = 1024;
  const { width, height } = image;
  return {
    height,
    width,
    samScale: longestSideTarget / Math.max(height, width),
    originalHeight: image.naturalHeight,
    originalWidth: image.naturalWidth,
  };
};

class Prout extends Component {
  constructor(props) {
    super(props);
    this.state = {
      image: null,
      modelScale: null,
      clickedPixel: null,
      clicks: [],
    };
    this.image = React.createRef();
  }

  initModel = async () => {
    try {
      const URL = 'models/export.onnx';
      const model = await InferenceSession.create(URL);
      this.setState({ model: model });
    } catch (e) {
      console.log(e);
    }
  };
  async componentDidMount() {
    this.initModel();

    const tensor = await this.loadNpyTensor(
      'models/embeddings_test.npy',
      'float32'
    );

    this.setState({ tensor });
  }

  loadNpyTensor = async (tensorFile, dType) => {
    let npLoader = new npyjs();
    const npArray = await npLoader.load(tensorFile);
    const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
    return tensor;
  };

  handleImageClick = async (event) => {
    const x = event.nativeEvent.offsetX;
    const y = event.nativeEvent.offsetY;
    const clicks = [];

    clicks.push({ x: x, y: y });
    this.setState({ clicks: clicks });

    const inputs = modelData({
      clicks: clicks,
      tensor: this.state.tensor,
      modelScale: this.state.modelScale,
    });
    console.log('click');

    const results = await this.state.model.run(inputs);
    const output = results[this.state.model.outputNames[0]];

    const image = onnxMaskToImage(output.data, output.dims[2], output.dims[3]);

    this.setState({ mask: image });
  };

  render() {
    return (
      <div
        style={{
          backgroundColor: 'red',
          display: 'flex',
          width: '100vw',
          height: '100vh',
          backgroundColor: colors.darkBlue,
          flexDirection: 'column',
          position: 'relative',
        }}
      >
        <div
          onClick={this.handleImageClick}
          style={{ width: 600, position: 'relative' }}
        >
          <img
            ref={(ref) => (this.image = ref)}
            src={imgTest}
            style={{
              width: '100%',
              height: 'auto',
              objectFit: 'contain',
              position: 'relative',
            }}
            onLoad={(event) => {
              const samScale = handleImageScale(event.target);
              this.setState({
                modelScale: samScale,
              });
            }}
          ></img>
          {this.state.mask ? (
            <img
              style={{
                opacity: 0.5,
                width: '100%',
                top: 0,
                left: 0,
                height: 'auto',
                objectFit: 'contain',
                position: 'absolute',
              }}
              src={this.state.mask.src}
            />
          ) : null}
        </div>
        <h1>This is a Test Component</h1>
      </div>
    );
  }
}

export default Prout;
