import { useTheme } from '@mui/material';
import { getViewportForBounds, useReactFlow, Viewport } from '@xyflow/react';
import { useCallback, useEffect, useState } from 'react';
import { UseMeasureRect } from 'react-use/lib/useMeasure';

import { useViewportTransform } from '@/components/diagrams/FlowChart/hooks/useViewportTransform';

import { EntityDiagramEffectFn, FitViewEffect } from '../types';

/**
 * Returns an effect handler that fits the diagram's visible nodes into the
 * measured container.
 *
 * - Zoom is derived from the bounds of all visible (non-hidden) nodes.
 * - Horizontally, only visible `tile` nodes are centred; other node types
 *   (e.g. sections) are dynamically resized, so their width should not affect
 *   horizontal placement.
 * - Vertically, the top of all visible nodes is aligned with the top of the
 *   container plus a small buffer.
 *
 * The first computed viewport is cached and reused by later calls until
 * `containerDimensions` changes, so repeated fit-view actions return to the
 * same viewport.
 *
 * @param containerDimensions Measured size of the element hosting the flow.
 * @returns An effect function handling `FitViewEffect` actions.
 */
export function useFitViewEffect({
  containerDimensions,
}: {
  containerDimensions: UseMeasureRect;
}): EntityDiagramEffectFn<FitViewEffect> {
  const theme = useTheme();

  const { setViewport, getNodes, getNodesBounds } = useReactFlow();
  const { getTransform } = useViewportTransform();
  // Holds the initial viewport so we can navigate back to it later
  const [fitViewViewport, setFitViewViewport] = useState<Viewport | null>(null);
  useEffect(() => {
    // Reset the fit view viewport when the container dimensions change
    setFitViewViewport(null);
  }, [containerDimensions]);

  return useCallback<EntityDiagramEffectFn<FitViewEffect>>(
    (_state, { perserveZoomLevel, kind, presentationMode }, _dispatch) => {
      let finalViewport = fitViewViewport;

      // Only compute a new viewport when nothing is cached for the current
      // container size; the cached one is reused as-is otherwise.
      if (finalViewport === null) {
        const nodes = getNodes();
        const currentZoom = getTransform()[2];

        // When preserving the zoom level, pin both bounds to the current zoom so
        // the fit only pans.
        const defaultMaxZoom = presentationMode ? 1.25 : 2;
        const minZoom = perserveZoomLevel ? currentZoom : 0.5;
        const maxZoom = perserveZoomLevel ? currentZoom : defaultMaxZoom;

        // This ratio is added to the bounds to create a buffer around the nodes
        const paddingRatio = 0.1;

        // All visible nodes drive the zoom and the vertical placement.
        const visibleNodes = nodes.filter((node) => !node.hidden);
        // Only visible tile nodes drive the horizontal centring.
        const tileNodes = visibleNodes.filter((node) => node.type === 'tile');

        const nodesBoundsY = getNodesBounds(visibleNodes);
        const viewportForBoundsY = getViewportForBounds(
          nodesBoundsY,
          containerDimensions.width,
          containerDimensions.height,
          minZoom,
          maxZoom,
          paddingRatio
        );
        const zoom = viewportForBoundsY.zoom;

        // Centre the tiles horizontally *at the zoom we actually apply*.
        // getViewportForBounds derives x from its own fitted zoom. Letting the
        // tile bounds pick their own (typically larger) zoom would yield an x
        // offset that no longer centres the tiles once `zoom` is applied.
        // Pinning min/max to `zoom` keeps the two consistent.
        const nodesBoundsX = getNodesBounds(tileNodes);
        const viewportForBoundsX = getViewportForBounds(
          nodesBoundsX,
          containerDimensions.width,
          containerDimensions.height,
          zoom,
          zoom,
          paddingRatio
        );

        // Applies a bit of space between the top of the container and the top of the nodes
        const topBuffer = presentationMode ? 40 : 80;

        finalViewport = {
          x: viewportForBoundsX.x,
          // Align the top edge of the visible nodes with the container top.
          y: -1 * nodesBoundsY.y * zoom + topBuffer,
          zoom,
        };
        setFitViewViewport(finalViewport);
      }

      // The initial fit is applied instantly; later fits animate.
      void setViewport(finalViewport, {
        duration:
          kind === 'init'
            ? undefined
            : theme.transitions.duration.enteringScreen,
      });
    },
    [
      containerDimensions,
      fitViewViewport,
      getNodes,
      getNodesBounds,
      getTransform,
      setViewport,
      theme.transitions.duration.enteringScreen,
    ]
  );
}
