import turf from "../util/paintpolygon/myTurf";
import { elen } from ".";
import * as ort from "onnxruntime-web";
import * as tf from "@tensorflow/tfjs";
import { LRUCache } from "lru-cache";
import { v4 as uuidv4 } from "uuid";
// onnxruntime-web WASM backend configuration.
ort.env.wasm.numThreads = 4;
// TODO: the worker proxy stays disabled because enabling it hits a
// WorkerGlobalScope issue ("failed to parse URL from ort-wasm-simd-threaded.wasm").
ort.env.wasm.proxy = false;
// Every WASM artifact is served from the site root under its own file name.
ort.env.wasm.wasmPaths = Object.fromEntries(
  [
    "ort-wasm.wasm",
    "ort-wasm-simd.wasm",
    "ort-wasm-threaded.wasm",
    "ort-wasm-simd-threaded.wasm",
    "ort-training-wasm-simd.wasm",
  ].map((fileName) => [fileName, `/${fileName}`])
);

// Pixel size (width, height) of the image handed to the SAM encoder; decode()
// and postprocess() work in this same frame.
const [IMG_WIDTH, IMG_HEIGHT] = [1024, 684];

// Quantized ViT-B image encoder, fetched from an external file host.
const ENCODER_URL = "https://files.sunu.in/sam_vit_b_01ec64.encoder.preprocess.quant.onnx";
// Prompt/mask decoder, served from this app's own public root ("/").
const DECODER_URL = "/sam_vit_b_01ec64.decoder.onnx";

/**
 * Scale a pixel quantity by a per-pixel physical size.
 *
 * @param {number} pixel - coordinate or length in pixels.
 * @param {number} mpp - physical units per pixel (presumably microns per
 *   pixel, given the names — TODO confirm).
 * @returns {number} `pixel * mpp`.
 */
function mic(pixel, mpp) {
  const scaled = pixel * mpp;
  return scaled;
}

// Element type and shape used to rebuild an embedding tensor from the raw
// bytes stored in the S3 cache.
const [IMAGE_EMBEDDING_TYPE, IMAGE_EMBEDDING_DIMS] = ["float32", [1, 256, 64, 64]];

// S3 embedding cache; objects live at s3://<CACHE_BUCKET>/<CACHE_KEY>/<id>.bin.
// NOTE(review): CACHE_REGION is not referenced by the code in this module.
const [CACHE_BUCKET, CACHE_REGION, CACHE_KEY] = [
  "valar-airen",
  "us-west-1",
  "airen0/.tulkas/sam-embeddings-cache",
];

/**
 * Collect everything a Node.js readable stream emits into a single Buffer.
 *
 * Chunk handling:
 *  - strings are encoded as UTF-8;
 *  - Buffers are kept as-is;
 *  - any other typed-array / DataView chunk, or an ArrayBuffer, contributes
 *    its raw bytes (respecting byteOffset/byteLength);
 *  - any other value (object-mode streams) is JSON-stringified and then
 *    UTF-8 encoded.
 *
 * @param {import("stream").Readable} readableStream
 * @returns {Promise<Buffer>} resolves with the concatenated bytes on "end",
 *   rejects with the stream's error on "error".
 */
async function stream_to_buffer(readableStream) {
  return new Promise((resolve, reject) => {
    const chunks = [];
    readableStream.on("data", (data) => {
      if (typeof data === "string") {
        // Convert string to Buffer assuming UTF-8 encoding
        chunks.push(Buffer.from(data, "utf-8"));
      } else if (Buffer.isBuffer(data)) {
        chunks.push(data);
      } else if (ArrayBuffer.isView(data)) {
        // Plain Uint8Array / other views are binary too; JSON-encoding them
        // would produce '{"0":..,"1":..}' instead of their bytes.
        chunks.push(Buffer.from(data.buffer, data.byteOffset, data.byteLength));
      } else if (data instanceof ArrayBuffer) {
        chunks.push(Buffer.from(data));
      } else {
        // Convert other data types to JSON and then to a Buffer
        chunks.push(Buffer.from(JSON.stringify(data), "utf-8"));
      }
    });
    readableStream.on("end", () => {
      resolve(Buffer.concat(chunks));
    });
    readableStream.on("error", reject);
  });
}

/**
 * Segment Anything Model (SAM) running through onnxruntime-web.
 *
 * Typical flow:
 *  1. `encode(image_data, id)` runs the image encoder. The embedding is
 *     memoised in an in-memory LRU cache and, when an `id` is given, also
 *     read from / written to an S3 cache through signed URLs.
 *  2. `decode(image_embeddings, bbox)` runs the decoder for a box prompt and
 *     returns the predicted mask as ImageData.
 *  3. `postprocess(mask_image_data, origin, mpp)` traces the mask outline and
 *     turns it into a polygon Feature.
 *
 * Both ONNX inference sessions are created lazily on first use and reused.
 */
class SAM {
  /**
   * @param {Object} store - provides `url_sign(url, action, contentType)`,
   *   used to sign the S3 URLs of the embedding cache.
   * @param {boolean} [debug=false] - log sessions, tensors and timings.
   * @param {Object} [cache_options={ max: 42 }] - options for the in-memory
   *   `LRUCache` that holds embeddings by id.
   */
  constructor(store, debug = false, cache_options = { max: 42 }) {
    this._embed_cache = new LRUCache(cache_options);
    this.sessions = {
      encoder: null,
      decoder: null,
    };
    this.debug = debug;
    // Used for caching items to S3
    this.store = store;
    // Always call url_sign with `store` as its receiver; a bare method
    // reference would run with `this` set to the SAM instance instead.
    this._url_sign = (...args) => this.store.url_sign(...args);
  }

  /**
   * Sign the S3 URL of the cached embedding for `id`.
   *
   * @param {string} id - embedding cache key.
   * @param {string} action - signing action ("getObject" / "putObject").
   * @returns {Promise<*>} whatever `store.url_sign` resolves to (presumably
   *   the signed URL string).
   */
  async _get_url(id, action) {
    const url = `s3://${CACHE_BUCKET}/${CACHE_KEY}/${id}.bin`;
    return await this._url_sign(url, action, "application/octet-stream");
  }

  /**
   * Look up an embedding previously persisted to the S3 cache.
   *
   * @param {string} id - embedding cache key.
   * @returns {Promise<ort.Tensor|null>} the embedding tensor, or `null` on a
   *   cache miss, a signing/network error, or a payload of unexpected size.
   */
  async read(id) {
    let buffer;
    try {
      const url = await this._get_url(id, "getObject");
      const response = await fetch(url);
      if (!response.ok) {
        // Not cached
        return null;
      }
      buffer = await response.arrayBuffer();
    } catch (err) {
      console.warn(`[sam.js:SAM.read] Error fetching embedding =${id}`, err);
      // Report a miss so encode() recomputes instead of caching a bogus value.
      return null;
    }

    const typedArrayConstructor = {
      float32: Float32Array,
      int32: Int32Array,
    }[IMAGE_EMBEDDING_TYPE];
    const expectedElements = IMAGE_EMBEDDING_DIMS.reduce((acc, d) => acc * d, 1);
    const expectedBytes =
      expectedElements * typedArrayConstructor.BYTES_PER_ELEMENT;
    if (buffer.byteLength !== expectedBytes) {
      // A truncated/corrupt object would otherwise make the typed-array or
      // ort.Tensor constructor throw.
      console.warn(
        `[sam.js:SAM.read] Ignoring embedding id=${id}: got ${buffer.byteLength} bytes, expected ${expectedBytes}`
      );
      return null;
    }

    const typedArray = new typedArrayConstructor(buffer);
    return new ort.Tensor(
      IMAGE_EMBEDDING_TYPE,
      typedArray,
      IMAGE_EMBEDDING_DIMS
    );
  }

  /**
   * Upload an embedding to the S3 cache. Best-effort: every failure is logged
   * with console.warn and swallowed, so the returned promise never rejects.
   *
   * @param {ort.Tensor} image_embeddings - tensor whose `data` is uploaded.
   * @param {string} id - embedding cache key.
   * @returns {Promise<void>}
   */
  async flush(image_embeddings, id) {
    try {
      const data = image_embeddings.data;
      if (!ArrayBuffer.isView(data)) {
        throw new TypeError("image_embeddings has no typed-array data to upload");
      }
      const url = await this._get_url(id, "putObject");
      const response = await fetch(url, {
        method: "PUT",
        // Upload the view itself rather than `data.buffer`, so only this
        // tensor's bytes are sent even if it is a view into a larger buffer.
        body: data,
        headers: {
          "Content-Type": "application/octet-stream",
        },
      });
      if (!response.ok) {
        throw new Error("Issues in flushing image_embeddings cache");
      }
    } catch (err) {
      console.warn(
        `[sam.js:flush] Error flushing image_embeddings with id=${id}`,
        err
      );
    }
  }

  /**
   * Compute (or fetch from cache) the SAM image embedding for `image_data`.
   *
   * Lookup order when `id` is given: in-memory LRU cache, then S3 cache, then
   * the encoder. Freshly computed embeddings are memoised and uploaded in the
   * background. Without an id, nothing is cached.
   *
   * @param {ImageData} image_data - presumably ImageData (anything that
   *   `ort.Tensor.fromImage` accepts).
   * @param {string|null} [id=null] - cache key for this image.
   * @returns {Promise<ort.Tensor|undefined|Object>} the embedding tensor;
   *   `undefined` if the encoder run failed; `{}` if the encoder session
   *   could not be created.
   */
  async encode(image_data, id = null) {
    const cacheable = id !== null;

    // Check local cache first
    if (cacheable && this._embed_cache.has(id)) {
      return this._embed_cache.get(id);
    }

    // Check S3 second, then cache the S3 result locally. Skipped without an
    // id: every id-less image would otherwise share the same "null" key.
    if (cacheable) {
      const cached = await this.read(id);
      if (cached !== null) {
        this._embed_cache.set(id, cached);
        return cached;
      }
    }

    const inputTensor = await this._prepare_encoder_input(image_data);

    let session = this.sessions.encoder;
    try {
      if (session === null) {
        this.sessions.encoder = session = await ort.InferenceSession.create(
          ENCODER_URL
        );
      }
    } catch (error) {
      console.log("Error creating encoder InferenceSession.", error);
      return {};
    }
    if (this.debug) {
      console.log("Encoder Session", session);
      console.log("Computing image embedding; this will take a minute...");
    }

    const start = Date.now();
    let image_embeddings;
    try {
      const results = await session.run({ input_image: inputTensor });
      if (this.debug) console.log("Encoding result:", results);
      image_embeddings = results.image_embeddings;
    } catch (error) {
      console.log(`caught error: ${error}`);
    }
    const time_taken = (Date.now() - start) / 1000;
    console.log(`Computing image embedding took ${time_taken} seconds`);

    // Only memoise and persist real results. Caching `undefined` after a
    // failed run would make every later call for this id return it.
    if (cacheable && image_embeddings !== undefined) {
      this._embed_cache.set(id, image_embeddings);
      // Fire-and-forget upload; flush() catches and logs its own errors.
      void this.flush(image_embeddings, id);
    }

    return image_embeddings;
  }

  /**
   * Build the encoder's `input_image` tensor from `image_data`.
   *
   * The image goes through `ort.Tensor.fromImage` (requesting an
   * IMG_WIDTH x IMG_HEIGHT size), back to ImageData, and is reloaded as a
   * tensor. That tensor is reshaped to planar [3, IMG_HEIGHT, IMG_WIDTH],
   * transposed to interleaved [IMG_HEIGHT, IMG_WIDTH, 3] and multiplied by
   * 255. The scaling is presumably there because fromImage normalises pixels
   * to [0, 1] while this encoder expects 0–255 values (TODO confirm).
   *
   * @param {ImageData} image_data
   * @returns {Promise<ort.Tensor>} tensor of shape [IMG_HEIGHT, IMG_WIDTH, 3].
   */
  async _prepare_encoder_input(image_data) {
    const resizedTensor = await ort.Tensor.fromImage(image_data, {
      resizedWidth: IMG_WIDTH,
      resizedHeight: IMG_HEIGHT,
    });
    const resizeImage = resizedTensor.toImageData();
    const imageDataTensor = await ort.Tensor.fromImage(resizeImage);
    if (this.debug) console.log("Image data tensor:", imageDataTensor);

    // tf.tidy releases the intermediate tf.js tensors, which otherwise leak
    // backend memory on every encode.
    const hwcTensor = tf.tidy(() =>
      tf
        .tensor(imageDataTensor.data, imageDataTensor.dims)
        .reshape([3, IMG_HEIGHT, IMG_WIDTH])
        .transpose([1, 2, 0])
        .mul(255)
    );
    try {
      return new ort.Tensor(hwcTensor.dataSync(), hwcTensor.shape);
    } finally {
      hwcTensor.dispose();
    }
  }

  /**
   * Predict a mask for a box prompt.
   *
   * @param {ort.Tensor} image_embeddings - output of `encode()`.
   * @param {{x: number, y: number, w: number, h: number}} bbox - box in pixel
   *   coordinates. Used as-is, so presumably in the IMG_WIDTH x IMG_HEIGHT
   *   frame of the encoded image.
   * @returns {Promise<ImageData|undefined|null>} the mask as ImageData;
   *   `undefined` if the decoder run failed; `null` if the decoder session
   *   could not be created.
   */
  async decode(image_embeddings, bbox) {
    const { x, y, w, h } = bbox;

    // Box prompt: top-left and bottom-right corners, labelled 2 and 3
    // (SAM's box-corner labels).
    const boxCoords = new ort.Tensor(
      new Float32Array([x, y, x + w, y + h]),
      [1, 2, 2]
    );
    const boxLabels = new ort.Tensor(new Float32Array([2, 3]), [1, 2]);

    // No prior mask: an all-zero (1, 1, 256, 256) placeholder, disabled via
    // has_mask_input = 0.
    const maskInput = new ort.Tensor(
      new Float32Array(256 * 256),
      [1, 1, 256, 256]
    );
    const hasMask = new ort.Tensor(new Float32Array([0]), [1]);
    // (height, width) of the image that was encoded. encode() feeds an
    // IMG_HEIGHT x IMG_WIDTH image, and postprocess() expects a mask that size.
    const originalImageSize = new ort.Tensor(
      new Float32Array([IMG_HEIGHT, IMG_WIDTH]),
      [2]
    );

    let session = this.sessions.decoder;
    try {
      if (session === null) {
        this.sessions.decoder = session = await ort.InferenceSession.create(
          DECODER_URL
        );
      }
    } catch (error) {
      console.error("Error creating InferenceSession.", error);
      return null;
    }
    if (this.debug) console.log("Decoder session", session);

    const decodingFeeds = {
      image_embeddings: image_embeddings,
      point_coords: boxCoords,
      point_labels: boxLabels,
      mask_input: maskInput,
      has_mask_input: hasMask,
      orig_im_size: originalImageSize,
    };

    const start = Date.now();
    let maskImageData;
    try {
      const results = await session.run(decodingFeeds);
      if (this.debug) console.log("Generated mask:", results);
      maskImageData = results.masks.toImageData();
    } catch (error) {
      console.error(`Caught error: ${error}`);
    }
    if (this.debug) {
      console.log(`generating masks took ${(Date.now() - start) / 1000} seconds`);
    }
    return maskImageData;
  }

  /**
   * Convert output of SAM into a polygon Feature to be displayed.
   *
   * The red channel of the mask is traced with marching squares. Each outline
   * point is shifted by `origin`, the ring is closed and simplified, and
   * finally every coordinate is scaled by `mpp` via `mic()`.
   *
   * @param {ImageData} mask_image_data - mask from `decode()`; assumed to be
   *   IMG_WIDTH x IMG_HEIGHT RGBA.
   * @param {Object} origin x/y-coordinates (pixel) of the origin of the
   *  image with respect to the georaster layer
   * @param {Number} origin.x
   * @param {Number} origin.y
   * @param {Number} mpp - per-pixel scale passed to `mic()`.
   * @returns {Object} polygon built by `elen.polygon`, with properties
   *   `{ id, state: "normal", type: "unlabeled" }`.
   */
  postprocess(mask_image_data, origin, mpp) {
    const width = IMG_WIDTH;
    const height = IMG_HEIGHT;
    const rgba = mask_image_data.data; // Uint8ClampedArray
    const pixelCount = width * height;

    // One byte per pixel for the contour tracer: the red channel of each
    // RGBA pixel.
    const data = new Uint8Array(pixelCount);
    for (let i = 0; i < pixelCount; ++i) {
      data[i] = rgba[i * 4];
    }

    // NOTE(review): MarchingSquaresOpt is not imported by this module; it is
    // presumably a global provided elsewhere — confirm.
    const outlinePoints = MarchingSquaresOpt.getBlobOutlinePoints(
      data,
      width,
      height
    ); // returns [x1,y1,x2,y2,x3,y3... etc.]

    const coordinates = [];
    for (let i = 0; i + 1 < outlinePoints.length; i += 2) {
      const px = outlinePoints[i];
      const py = outlinePoints[i + 1];
      // Skip only missing/non-numeric entries: 0 is a legitimate coordinate
      // for outlines touching the top or left edge of the mask.
      if (Number.isFinite(px) && Number.isFinite(py)) {
        coordinates.push([px + origin.x, py + origin.y]);
      }
    }

    // Ensure the polygon is closed by repeating the first point if needed.
    if (coordinates.length > 0) {
      const [firstX, firstY] = coordinates[0];
      const [lastX, lastY] = coordinates[coordinates.length - 1];
      if (firstX !== lastX || firstY !== lastY) {
        coordinates.push([firstX, firstY]);
      }
    }

    const properties = {
      id: uuidv4(),
      state: "normal",
      type: "unlabeled",
    };
    const geojson = elen.polygon([coordinates]);
    const options = { tolerance: 0.01, highQuality: false };
    const simplified = turf.simplify(geojson, options);
    const retGeom = elen.transform(simplified.geometry, ([gx, gy]) => [
      mic(gx, mpp),
      mic(gy, mpp),
    ]);
    return elen.polygon(retGeom.coordinates, properties);
  }
}

export default SAM;
