import React, { useEffect, useRef, useState, useCallback } from "react";
import { Heatmap, HeatmapProps } from "./Heatmap";
import { useObserveSizeOf } from "./useObserveSizeOf";
import { ScaleContinuousNumeric } from "d3-scale";
import { select } from "d3-selection";
import { axisBottom, axisLeft } from "d3-axis";

/** Props for the `HeatmapWithAxis` component. */
export interface HeatmapWithAxisProps {
  /** Props forwarded to the inner `Heatmap`; its `width` and `height` also determine the axis layout. */
  heatmapProps: HeatmapProps;
  /**
   * Scale used to draw the bottom axis. When omitted, no x axis is rendered.
   * Its range is expected in unscaled heatmap pixels; it is rescaled when `fullWidth` resizes the heatmap.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- accepts any continuous d3 scale
  xScale?: ScaleContinuousNumeric<any, any, any>;
  /**
   * Scale used to draw the left axis. When omitted, no y axis is rendered.
   * Its range is expected in unscaled heatmap pixels; it is rescaled when `fullWidth` resizes the heatmap.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- accepts any continuous d3 scale
  yScale?: ScaleContinuousNumeric<any, any, any>;
  /** When true, the heatmap is scaled so that it plus the left axis gutter fills the wrapper's measured width. */
  fullWidth?: boolean;
  /** Content rendered inside the `Heatmap`. Declared explicitly because newer `React.FC` typings omit it. */
  children?: React.ReactNode;
}

/** Width in pixels of the gutter reserved around the heatmap (top, left, bottom padding) in which the axes are drawn. */
const AXIS_SPACE = 40;

/**
 * Factor by which the heatmap (and its axes) are drawn.
 *
 * Returns 1 when `fullWidth` is off. Otherwise returns the ratio that makes the heatmap plus the
 * left axis gutter fill `availableWidth`. The result is clamped to 0 when the wrapper has not been
 * measured yet or is narrower than the gutter: a negative factor would mirror the heatmap and
 * produce negative SVG dimensions.
 */
function computeDisplayScale(
  fullWidth: boolean | undefined,
  availableWidth: number | undefined,
  contentWidth: number
): number {
  if (!fullWidth) {
    return 1;
  }
  // Guard against division by zero / a meaningless ratio for an empty heatmap.
  if (!(contentWidth > 0)) {
    return 1;
  }
  const ratio = ((availableWidth ?? 0) - AXIS_SPACE) / contentWidth;
  return Number.isFinite(ratio) && ratio > 0 ? ratio : 0;
}

/**
 * Multiplies every endpoint of a scale range by `factor`.
 * Each endpoint is scaled independently, so inverted ranges such as `[height, 0]` keep their orientation.
 */
function scaleRange(range: readonly number[], factor: number): number[] {
  return range.map((value) => value * factor);
}

/**
 * Renders a `Heatmap` with optional d3 bottom (x) and left (y) axes drawn in an SVG overlay.
 *
 * The wrapper reserves `AXIS_SPACE` pixels of padding on the top, left and bottom for the axes.
 * With `fullWidth`, the heatmap is CSS-scaled to the wrapper's measured width, and the axis ranges
 * are multiplied by the same factor so that ticks stay aligned with the cells.
 */
export const HeatmapWithAxis: React.FC<HeatmapWithAxisProps> = ({
  heatmapProps,
  xScale,
  yScale,
  children,
  fullWidth,
}) => {
  const xAxisRef = useRef<SVGGElement>(null);
  const yAxisRef = useRef<SVGGElement>(null);
  const [wrapper, setWrapper] = useState<HTMLDivElement | undefined>(undefined);
  const wrapperSize = useObserveSizeOf(wrapper);

  // The wrapper node is kept in state (not only a ref) so that a render happens once it exists and
  // useObserveSizeOf receives the element. React passes null on detach; normalise it to the same
  // `undefined` used for the initial state.
  const wrapperCallbackRef = useCallback((node: HTMLDivElement | null) => {
    setWrapper(node ?? undefined);
  }, []);

  const scale = computeDisplayScale(fullWidth, wrapperSize[0], heatmapProps.width);
  const scaledWidth = heatmapProps.width * scale;
  const scaledHeight = heatmapProps.height * scale;
  // Horizontal shift applied (in scaled space) to the heatmap container when it is shrunk.
  const centeringOffset = Math.max(0, (heatmapProps.width - scaledWidth) / 2);

  // The axis <g> elements are drawn imperatively by d3-axis whenever the scale or layout changes.
  useEffect(() => {
    const axisNode = xAxisRef.current;
    if (!xScale || !axisNode) {
      return;
    }
    const axisScale = xScale.copy().range(scaleRange(xScale.range(), scale));
    select(axisNode)
      .attr("transform", `translate(${AXIS_SPACE},${scaledHeight + AXIS_SPACE})`)
      .call(axisBottom(axisScale));
  }, [xScale, scale, scaledHeight]);

  useEffect(() => {
    const axisNode = yAxisRef.current;
    if (!yScale || !axisNode) {
      return;
    }
    const axisScale = yScale.copy().range(scaleRange(yScale.range(), scale));
    select(axisNode)
      .attr("transform", `translate(${AXIS_SPACE}, ${AXIS_SPACE})`)
      .call(axisLeft(axisScale));
  }, [yScale, scale]);

  return (
    <div
      style={{
        display: fullWidth ? "block" : "inline-block",
        position: "relative",
        paddingBottom: `${AXIS_SPACE}px`,
        paddingTop: `${AXIS_SPACE}px`,
        paddingLeft: `${AXIS_SPACE}px`,
        height: `${scaledHeight}px`,
      }}
      ref={wrapperCallbackRef}
    >
      <svg
        style={{
          position: "absolute",
          top: `0`,
          left: `0`,
        }}
        width={scaledWidth + AXIS_SPACE * 2}
        height={scaledHeight + AXIS_SPACE * 2}
      >
        {xScale && <g ref={xAxisRef} />}
        {yScale && <g ref={yAxisRef} />}
      </svg>
      <div
        style={{
          display: "block",
          transform: `scale(${scale}) translate(-${centeringOffset}px, 0)`,
          transformOrigin: "center top",
        }}
      >
        <Heatmap {...heatmapProps}>{children}</Heatmap>
      </div>
    </div>
  );
};
