import {
  chunk,
  compact,
  floor,
  forEach,
  groupBy,
  maxBy,
  orderBy,
  uniqBy,
} from 'lodash';

import { Edge, TileNode } from '../types';
import { getBoundingBoxForNodes, getNodeMeasuredWidthHeight } from './nodes';

// Builds the grid (a TileNode[][]) that the layout functions operate on.
//
// - Nodes are grouped into sections by the part of their id before the
//   first ':'. A row only ever contains nodes belonging to one section.
// - Within a section the nodes are sorted by measured height, then nodes
//   that are the source of at least one edge are moved to the end, and the
//   result is split into rows of at most `chunkSize` nodes.
// - Within each row, nodes are then sorted by the column index of the
//   source node of their first incoming edge (nodes without one go last).
//   This helps to reduce the likelihood of edges crossing.
// - Finally, each row is sorted by the external `data.order` value, falling
//   back to the node's index in the row when `order` is not set.
//
// All sorts use lodash `orderBy`, so ties keep their previous relative order.
//
// NOTE(review): the rows returned here contain shallow copies of the input
// nodes with an extra `idx` property. The copies still share their
// `position` object with the input nodes until a caller reassigns
// `position`, so in-place writes to `position.x`/`position.y` are visible on
// the input nodes as well. Confirm whether that is intended.
function getGridFromNodes(
  nodes: TileNode[],
  edges: Edge[],
  chunkSize: number
): TileNode[][] {
  const edgesBySource: Record<string, Edge[]> = groupBy(
    edges,
    (edge) => edge.source
  );

  const edgesByDestination: Record<string, Edge[]> = groupBy(
    edges,
    (edge) => edge.target
  );

  // Group nodes by section (the id prefix before the first ':')
  const sectionMap = groupBy(nodes, (node) => node.id.split(':')[0]);

  const chunkedRows: TileNode[][] = [];

  // Chunk nodes within each section
  forEach(sectionMap, (sectionNodes) => {
    // Sort nodes by height so nodes of similar height are next to each other
    const byHeight = orderBy(
      sectionNodes,
      (node) => getNodeMeasuredWidthHeight(node).height
    );

    // If the node is a source, move it to the end of the array
    const sourcesLast = orderBy(byHeight, (node) =>
      edgesBySource[node.id] ? 1 : 0
    );

    chunkedRows.push(...chunk(sourcesLast, chunkSize));
  });

  // Column index of every node in the chunked (not yet re-sorted) rows.
  // Built once so the sort below does not have to scan the whole grid for
  // every node. If an id occurs more than once, the first occurrence wins.
  const columnIndexById = new Map<string, number>();
  chunkedRows.forEach((row) => {
    row.forEach((node, columnIndex) => {
      if (!columnIndexById.has(node.id)) {
        columnIndexById.set(node.id, columnIndex);
      }
    });
  });

  // Within each row, sort the nodes by the column of their source node
  // so the edges are less likely to cross
  const sourceOrderedRows = chunkedRows.map((row) =>
    orderBy(row, (node) => {
      // Only the first incoming edge is considered
      const sourceNodeId = edgesByDestination[node.id]?.[0]?.source;

      if (!sourceNodeId) {
        // If there is no source node, push the node to the end of the row
        return row.length;
      }

      // A source that is not part of the grid also sorts to the end
      return columnIndexById.get(sourceNodeId) ?? row.length;
    })
  );

  // Finally, within each row, sort the nodes by the external order parameter
  return sourceOrderedRows.map((row) =>
    orderBy(
      row.map((node, idx) => ({ ...node, idx })),
      (node) => node.data.order ?? node.idx
    )
  );
}

// Horizontally aligns the center of each row with the center of the row
// above it by shifting `position.x` of every node in the row (in place).
//
// Rows are processed top to bottom, so a row is compared against the already
// shifted positions of the row above. The first row is never shifted.
function applyCenterGridNodes(grid: TileNode[][]) {
  // Center of a row: midpoint between the x positions of the first and the
  // last node that are not flagged with `data.omitFromCentering`, rounded
  // down to one decimal place. A row without such nodes has center 0.
  const getCenterOfRow = (row: TileNode[]) => {
    const centeringNodes = row.filter((node) => !node.data.omitFromCentering);
    const firstX = centeringNodes[0]?.position.x ?? 0;
    const lastX = centeringNodes[centeringNodes.length - 1]?.position.x ?? 0;
    return floor((firstX + lastX) / 2, 1);
  };

  grid.forEach((row, rowIndex) => {
    // Use the callback index rather than searching the grid for the row
    const aboveRow = grid[rowIndex - 1];

    const xOffset = aboveRow
      ? getCenterOfRow(row) - getCenterOfRow(aboveRow)
      : 0;

    // Every node is shifted, including the ones omitted from the center
    // calculation
    row.forEach((node) => {
      node.position.x -= xOffset;
    });
  });
}

// Assigns an x/y position to every node in the grid.
//
// - Rows are laid out top to bottom. Each row starts at y = sum of the
//   previous rows' tallest measured node height plus `paddingY` per row.
// - Within a row, nodes are placed left to right, separated by `paddingX`.
// - A node with `data.stackY` that directly follows another `data.stackY`
//   node in the same row is placed underneath it (same x) instead of next
//   to it, separated by half of `paddingY`.
// - Columns (a run of stacked nodes, or a single node) containing a node
//   with `data.centerY` are vertically centered using the height of their
//   section (nodes sharing the same `data.sectionLabelId`).
// - If `centerGrid` is true, rows are finally centered horizontally via
//   applyCenterGridNodes.
function applyGridLayoutNodes(
  grid: TileNode[][],
  paddingX: number,
  paddingY: number,
  centerGrid = true
) {
  // Top edge of the row currently being laid out
  let yOffset = 0;

  grid.forEach((row) => {
    // Left edge of the next non-stacked node in this row
    let xOffset = 0;

    const tallestNode = maxBy(
      row,
      (node) => getNodeMeasuredWidthHeight(node).height
    );
    const maxHeight = tallestNode
      ? getNodeMeasuredWidthHeight(tallestNode).height
      : 0;

    row.forEach((node, colIdx) => {
      const previousNode = row[colIdx - 1];

      if (node.data.stackY && previousNode?.data.stackY) {
        // Stack this node under the previous node: same x, and y just below
        // the previous node with half of paddingY as the gap.
        // The existing position object is updated in place.
        node.position.x = previousNode.position.x;
        node.position.y =
          previousNode.position.y +
          (getNodeMeasuredWidthHeight(previousNode).height + paddingY / 2);
      } else {
        node.position = { x: xOffset, y: yOffset };
        const { width } = getNodeMeasuredWidthHeight(node);
        xOffset += floor(width + paddingX, 1); // Apply paddingX between nodes
      }
    });

    // Stacked nodes do not contribute extra height here; only the tallest
    // single node of the row does.
    yOffset += maxHeight + paddingY; // Apply paddingY between rows
  });

  // Flatten once; the grid structure does not change below
  const allNodes = grid.flat();

  const uniqueSectionIds = compact(
    uniqBy(allNodes, (node) => node.data.sectionLabelId).map(
      (node) => node.data.sectionLabelId
    )
  );

  // Height of the bounding box of every section, keyed by sectionLabelId
  const sectionHeights: Record<string, number> = {};

  uniqueSectionIds.forEach((sectionId) => {
    const boundsForSection = getBoundingBoxForNodes(
      allNodes.filter((n) => n.data.sectionLabelId === sectionId)
    );
    sectionHeights[sectionId] = boundsForSection.bottom - boundsForSection.top;
  });

  // Group the nodes into columns: consecutive stackY nodes within one row
  // form a single column, every other node is a column of its own.
  // A stack never continues across a row boundary, matching the positioning
  // pass above, which only stacks a node under its predecessor in the same
  // row.
  const nodeColumns: TileNode[][] = [];

  for (const row of grid) {
    // The column that a following stackY node would be appended to
    let openStack: TileNode[] | undefined;

    for (const node of row) {
      if (node.data.stackY && openStack) {
        openStack.push(node);
      } else {
        const column = [node];
        nodeColumns.push(column);
        openStack = node.data.stackY ? column : undefined;
      }
    }
  }

  // Additionally, if a node has a centerY property, we need to center its column
  nodeColumns.forEach((column) => {
    const firstNode = column[0];
    const sectionLabelId = firstNode?.data.sectionLabelId;

    if (
      !firstNode ||
      !sectionLabelId ||
      !column.some((node) => node.data.centerY)
    ) {
      return;
    }

    const sectionHeight = sectionHeights[sectionLabelId];

    if (!sectionHeight) {
      return;
    }

    const columnBounds = getBoundingBoxForNodes(column);
    const columnHeight = columnBounds.bottom - columnBounds.top;

    // NOTE(review): only the first node of the column is moved, and the
    // offset is not added to the section's top. This looks correct only for
    // single-node columns in a section starting at y = 0 - confirm.
    firstNode.position.y = (sectionHeight - columnHeight) / 2;
  });

  if (centerGrid) {
    applyCenterGridNodes(grid);
  }
}

// Sets `edge.data.centerFactor` on every edge that has a `data` object,
// based on the vertical distance (y travel) the edge has to cover.
// The factor is (yTravel - 32) / yTravel, so the edge breaks closer to the
// target node as the y travel increases.
//
// Edges without `data` are left untouched. When the factor is not a finite
// number (the y travel is 0, e.g. because neither node is in the grid or the
// two boxes touch), `centerFactor` is not written, so the edge keeps whatever
// value it had instead of receiving -Infinity/NaN.
//
// NOTE(review): the y position and height are read from the FIRST node of
// the row containing the source/target, not from the source/target node
// itself. This is presumably meant to be the row's top - confirm, since it
// differs for stacked nodes or rows with nodes of different heights.
function applyModifiedEdgeCenterFactor(grid: TileNode[][], edges: Edge[]) {
  // First node of the row that contains the node with the given id
  const firstNodeOfRowContaining = (nodeId: string) =>
    grid.find((row) => row.some((node) => node.id === nodeId))?.[0];

  // Set the centerFactor for edges based on how long the Y travel of the edge is
  edges.forEach((edge) => {
    const sourceNode = firstNodeOfRowContaining(edge.source);
    const targetNode = firstNodeOfRowContaining(edge.target);

    // Missing nodes or measurements count as 0
    const topOfTargetNode = targetNode?.position.y ?? 0;
    const bottomOfSourceNode =
      (sourceNode?.position.y ?? 0) + (sourceNode?.measured?.height ?? 0);

    const yTravel = topOfTargetNode - bottomOfSourceNode;

    // The border radius of our edges is 16px
    // so always leave 32px of space on the top and bottom
    const centerFactor = (yTravel - 32) / yTravel;

    // Guard against division by zero (and NaN positions)
    if (edge.data && Number.isFinite(centerFactor)) {
      edge.data.centerFactor = centerFactor;
    }
  });
}

// Options accepted by gridLayout.
interface GridLayoutProps {
  // Nodes to lay out. They are grouped into sections by the part of their
  // id before the first ':'.
  nodes: TileNode[];
  // Maximum number of nodes per grid row.
  chunkSize: number;
  // Horizontal space added after each node's measured width.
  paddingX: number;
  // Vertical space between rows; half of it is used between stacked nodes.
  paddingY: number;
  // Edges between the nodes. Used to order nodes within a row and, if
  // enabled, updated with a centerFactor.
  edges: Edge[];
  // Whether to set `data.centerFactor` on the edges. Defaults to true in
  // gridLayout. (The spelling of the name is part of the public interface.)
  applyModifyedEdgeCenterFactor?: boolean;
}

// Lays the given nodes out as a grid and returns that grid together with
// the edges.
//
// 1. The nodes are arranged into rows (see getGridFromNodes).
// 2. Every node in the grid gets a position, including horizontal centering
//    of the rows (see applyGridLayoutNodes).
// 3. Unless `applyModifyedEdgeCenterFactor` is false, the edges are updated
//    so that they break closer to the target node as the y travel of the
//    edge increases.
// 4. The per-node `data.xOffset` / `data.yOffset` values (0 when not set)
//    are added on top of the computed positions.
//
// The computed positions should be read from the nodes in the returned grid.
export function gridLayout({
  nodes,
  chunkSize,
  paddingX,
  paddingY,
  edges,
  applyModifyedEdgeCenterFactor = true,
}: GridLayoutProps): {
  grid: TileNode[][];
  edges: Edge[];
} {
  const grid: TileNode[][] = getGridFromNodes(nodes, edges, chunkSize);

  // Position the nodes (centering is on by default)
  applyGridLayoutNodes(grid, paddingX, paddingY);

  if (applyModifyedEdgeCenterFactor) {
    applyModifiedEdgeCenterFactor(grid, edges);
  }

  // Shift every node by its externally supplied offsets
  for (const row of grid) {
    for (const node of row) {
      node.position.x += node.data.xOffset ?? 0;
      node.position.y += node.data.yOffset ?? 0;
    }
  }

  return { grid, edges };
}
