import { Position } from '@xyflow/react';
import { getYear } from 'date-fns';
import Decimal from 'decimal.js';
import { compact, groupBy, keyBy, uniq } from 'lodash';

import { TileBeforeAndAfter } from '@/components/diagrams/components/Tile/TileBeforeAndAfter';
import { TileVariant } from '@/components/diagrams/components/Tile/types';
import { Edge } from '@/components/diagrams/FlowChart';
import { BadgeVariants } from '@/components/notifications/Badge/Badge';
import { ContextPrimaryClient } from '@/modules/household/contexts/householdDetails.context';
import { AfterDeath, EstateWaterfallEdgeKind } from '@/types/schema';
import { assertNonNil } from '@/utils/assertUtils';
import { sumDecimalJS } from '@/utils/decimalJSUtils';
import { formatCurrency } from '@/utils/formatting/currency';
import { FlowChartGraph } from '@/utils/graphology/FlowChartGraph';
import { getNodes } from '@/utils/graphqlUtils';

import {
  getCombinedTaxStateAndFederal,
  getSectionLabel,
  isHypotheticalWaterfall,
} from '../EstateWaterfall.utils';
import {
  EstateWaterfall_NodeFragment,
  EstateWaterfallFragment,
} from '../graphql/EstateWaterfall.generated';
import {
  EstateWaterfallGraph,
  EstateWaterfallGraphAttributes,
  EstateWaterfallGraphEdgeAttributes,
  EstateWaterfallGraphNodeAttributes,
  GraphNodeCategorizationType,
  WaterfallFeatureFlags,
} from '../types';
import {
  ESTATE_TAX_NODE_IDENTIFIER,
  EXTERNAL_TRANSFER_NODE_IDENTIFIER,
} from './constants';
import { SHOW_CREATE_GROUP_MODAL_SENTINEL } from './CreateNewGroupModal.utils';
import { drawEdges } from './drawEdges';
import { GroupNodeItems } from './GroupNodeItems';
import {
  addTaxEdge,
  buildTileFromNode,
  getBeforeAndAfterData,
  getCategorizationType,
  getEstateTaxAfterDeath,
  getExternalTransferNodeId,
  getGiftTaxNodeId,
  getGroupNodeTileVariant,
  getNodeId,
  getTaxNodeId,
  isTaxNode,
} from './utils';

// TODO (LUM-2026): Very duplicative with entity map, share this
export type EdgeInput = Omit<Edge, 'id' | 'type'>;

/**
 * Builds a flow-chart `Edge` from `source` to `target`.
 *
 * - The id is derived as `"<source>:<target>"` and the type is always `'arrow'`.
 * - `data.hideLabel` defaults to `true`; any provided `data` fields win.
 * - Handles default to leaving the bottom of the source and entering the top
 *   of the target. Every remaining property on the input is spread last, so
 *   it overrides these defaults (including `sourceHandle` / `targetHandle`).
 */
export const createEdge = (input: EdgeInput): Edge => {
  const { source, target, data: dataOverrides, ...overrides } = input;

  return {
    id: `${source}:${target}`,
    source,
    target,
    type: 'arrow',
    data: { hideLabel: true, ...dataOverrides },
    sourceHandle: Position.Bottom,
    targetHandle: Position.Top,
    ...overrides,
  };
};

/** Inputs for {@link createWaterfallGraph}. */
interface CreateWaterfallGraphInput {
  /** The waterfall whose projected visualization is converted into a graph. */
  waterfall: EstateWaterfallFragment;
  /** Primary clients; used to resolve the first- and second-death grantors for section labels. */
  grantors: ContextPrimaryClient[];
  /** When true, a section for the second death is added after the first-death section. */
  isTwoClientHousehold: boolean;
  /**
   * Graph node IDs. An external transfer is only rendered when the pre-death
   * (`AfterDeath.None`) node ID of its source/destination is in this list.
   */
  visibleNodeIds: string[];
  /**
   * Node IDs present in the waterfall's initial state. When no visualization
   * node has a saved configuration, non-tax nodes outside this set are
   * flagged as new.
   */
  initWaterfallNodeIds: Set<string>;
  /**
   * When true and no visualization node has a saved configuration, no tile
   * is flagged as new. Defaults to `false`.
   */
  isGeneratingDefaultState?: boolean;
  /** Waterfall feature flags; not currently read by `createWaterfallGraph`. */
  featureFlags?: WaterfallFeatureFlags;
}

/**
 * Adds a secondary-styled "Transfer" edge from `source` to `target`, labeled
 * with `amount` formatted as compact currency. Used for every edge that
 * connects a tile to one of the external transfer nodes.
 */
function addExternalTransferEdge(
  graph: FlowChartGraph<
    EstateWaterfallGraphNodeAttributes,
    EstateWaterfallGraphEdgeAttributes,
    EstateWaterfallGraphAttributes
  >,
  source: string,
  target: string,
  amount: Decimal
): void {
  const edge = createEdge({
    source,
    target,
    data: {
      variant: 'secondary',
      edgeLabel: {
        variant: 'secondary',
        label: 'Transfer',
        value: formatCurrency(amount, {
          notation: 'compact',
        }),
      },
    },
  });
  graph.addEdgeSafe(source, target, {
    type: 'default',
    edge,
  });
}

/**
 * Builds the flow-chart graph for an estate waterfall visualization.
 *
 * Nodes are added in this order:
 * 1. one section-label node per death section (`None`, `First`, and `Second`
 *    for two-client households);
 * 2. "External destination" / "External source" tiles when there are visible,
 *    positive external transfers;
 * 3. estate tax tiles per death and, for hypothetical waterfalls, a gift tax
 *    tile (unless `hideTaxTiles` is set);
 * 4. a tile per visualization node, plus its transfer and tax edges;
 * 5. a tile per visualization group, plus its aggregated transfer and tax edges.
 * Finally `drawEdges` adds the visualization edges between nodes.
 *
 * @returns the populated graph, with the `firstPrimaryClientDeathId` and
 *   `waterfall` graph attributes set.
 */
export function createWaterfallGraph({
  grantors,
  waterfall,
  isTwoClientHousehold,
  visibleNodeIds,
  isGeneratingDefaultState = false,
  initWaterfallNodeIds,
}: CreateWaterfallGraphInput): EstateWaterfallGraph {
  const graph = new FlowChartGraph<
    EstateWaterfallGraphNodeAttributes,
    EstateWaterfallGraphEdgeAttributes,
    EstateWaterfallGraphAttributes
  >();

  const {
    visualizationWithProjections,
    visualizationConfig,
    firstGrantorDeath,
    firstGrantorDeathYear,
    secondGrantorDeathYear,
    visualizationGroups,
    outgoingExternalTransfers,
    incomingExternalTransfers,
    hypotheticalTransfers,
  } = waterfall;

  const viz = visualizationWithProjections;

  const vizNodes = viz.nodes;

  const {
    beforeFirstDeathTaxSummary,
    firstDeathTaxSummary,
    secondDeathTaxSummary,
  } = viz;
  // Visualization nodes that belong to a group
  const groupedNodes = viz.nodes.filter((n) => !!n.group?.id);

  /** Whether the waterfall has a saved node configuration for `nodeId`. */
  function nodeIdHasConfiguration(nodeId: string) {
    return waterfall.nodeConfigurations.some(
      (config) => config.nodeID === nodeId
    );
  }

  // True when no visualization node (keyed by its group's id when grouped)
  // has a saved node configuration.
  const noneOfTheVizNodesHaveConfigurations = (() => {
    const numVizNodesWithConfig = viz.nodes.filter((vizNode) =>
      nodeIdHasConfiguration(
        getNodeId({
          id: vizNode.group?.id ?? vizNode.id,
          afterDeath: vizNode.afterDeath,
        })
      )
    ).length;

    return numVizNodesWithConfig === 0;
  })();

  /** Whether the tile with graph node id `id` should be flagged as new. */
  function getIsNewNodeById(id: string) {
    if (noneOfTheVizNodesHaveConfigurations) {
      // If we don't have any node configurations, only nodes not included in the default state are new
      if (isGeneratingDefaultState) {
        // We're generating the initial graph with no configurations, so none of the nodes should be considered new
        return false;
      }

      if (!initWaterfallNodeIds.has(id) && !isTaxNode(id)) {
        // This node was not included in the initial state and should be considered new
        return true;
      }

      // All other nodes are not considered new
      return false;
    }
    return !nodeIdHasConfiguration(id);
  }

  // { None:groupviz_uuid1: Node[], First:groupviz_uuid1: Node[] }
  const groupedNodesByGroupNodeId = groupBy(groupedNodes, (n) =>
    // Nodes were pre-filtered to guarantee they include a group id
    getNodeId({ id: n.group!.id, afterDeath: n.afterDeath })
  ) as Record<string, EstateWaterfall_NodeFragment[]>;

  // { None:groupviz_uuid1: GraphVizGroup, First:groupviz_uuid1: GraphVizGroup }
  const visualizationGroupsByGroupNodeId = keyBy(
    getNodes(visualizationGroups),
    (g) => getNodeId({ id: g.id, afterDeath: g.namespace as AfterDeath })
  );

  // Always show at least the first two sections
  const deathSections = [AfterDeath.None, AfterDeath.First];
  if (isTwoClientHousehold) {
    deathSections.push(AfterDeath.Second);
  }
  const parentNodeIds = new Set(deathSections);

  const firstDeathGrantor = grantors.find((g) => g.id === firstGrantorDeath.id);
  const secondDeathGrantor = grantors.find(
    (g) => g.id !== firstGrantorDeath.id
  );

  // Build the section group nodes
  parentNodeIds.forEach((parentNodeId) => {
    const afterDeath = parentNodeId;

    const label = getSectionLabel({
      afterDeath,
      firstDeathName: firstDeathGrantor?.firstName,
      secondDeathName: secondDeathGrantor?.firstName,
      firstDeathYear: firstGrantorDeathYear || getYear(new Date()),
      secondDeathYear: secondGrantorDeathYear || getYear(new Date()),
    });

    graph.addNodeSafe(parentNodeId, {
      data: { afterDeath, id: parentNodeId },
      node: {
        id: parentNodeId,
        position: { x: 0, y: 0 },
        data: { label },
        type: 'sectionLabel',
        draggable: false,
      },
      categorizationType: GraphNodeCategorizationType.SectionLabel,
    });
  });

  // Build the external transfer nodes. Only transfers whose source node is
  // visible in the pre-death section and whose amount is positive are kept.
  const outgoingTransferSourceNodes = compact(
    getNodes(outgoingExternalTransfers).map((transfer) => {
      const sourceId =
        transfer.sourceEntity?.id || transfer.sourceIndividual?.id;
      if (!sourceId) return null;

      if (
        !visibleNodeIds.includes(
          getNodeId({ id: sourceId, afterDeath: AfterDeath.None })
        )
      ) {
        return null;
      }

      const transferAmount = transfer.transferValue || new Decimal(0);

      if (transferAmount.lessThanOrEqualTo(0)) {
        return null;
      }

      return {
        sourceId,
        transferAmount,
      };
    })
  );

  const outgoingTransferNodeId = getExternalTransferNodeId('outgoing');
  if (outgoingTransferSourceNodes.length) {
    graph.addNodeSafe(outgoingTransferNodeId, {
      data: {
        afterDeath: AfterDeath.None,
        id: EXTERNAL_TRANSFER_NODE_IDENTIFIER,
      },
      node: {
        id: outgoingTransferNodeId,
        type: 'tile',
        position: { x: 0, y: 0 },
        data: {
          lineOne: 'External destination',
          lineTwo: '',
          lineThree: '',
          variant: TileVariant.Primary,
          sectionLabelId: AfterDeath.None,
          isNewTile: getIsNewNodeById(outgoingTransferNodeId),
        },
      },
      categorizationType: GraphNodeCategorizationType.ExternalSource,
    });
  }

  // Same filtering as outgoing transfers, keyed on the destination node. A
  // non-positive transfer must not create an "External source" tile or a
  // zero-valued "Transfer" edge.
  const incomingTransferSourceNodes = compact(
    getNodes(incomingExternalTransfers).map((transfer) => {
      const targetId =
        transfer.destinationEntity?.id ||
        transfer.destinationIndividual?.id ||
        transfer.destinationOrganization?.id;

      if (!targetId) {
        return null;
      }

      if (
        !visibleNodeIds.includes(
          getNodeId({ id: targetId, afterDeath: AfterDeath.None })
        )
      ) {
        return null;
      }

      const transferAmount = transfer.transferValue || new Decimal(0);

      if (transferAmount.lessThanOrEqualTo(0)) {
        return null;
      }

      return {
        transferAmount,
        targetId,
      };
    })
  );
  const incomingTransferNodeId = getExternalTransferNodeId('incoming');

  if (incomingTransferSourceNodes.length) {
    graph.addNodeSafe(incomingTransferNodeId, {
      data: {
        afterDeath: AfterDeath.None,
        id: EXTERNAL_TRANSFER_NODE_IDENTIFIER,
      },
      node: {
        id: incomingTransferNodeId,
        type: 'tile',
        position: { x: 0, y: 0 },
        data: {
          lineOne: 'External source',
          lineTwo: '',
          lineThree: '',
          variant: TileVariant.Primary,
          sectionLabelId: AfterDeath.None,
          isNewTile: getIsNewNodeById(incomingTransferNodeId),
        },
      },
      categorizationType: GraphNodeCategorizationType.ExternalSource,
    });
  }

  const pouroverEdges = viz.edges.filter(
    ({ kind }) => EstateWaterfallEdgeKind.Pourover === kind
  );
  const idsWithIncomingPouroverDispositions = uniq(
    pouroverEdges.map(({ to }) => to.id)
  );
  const idsWithOutgoingPouroverDispositions = uniq(
    pouroverEdges.map(({ from }) => from.id)
  );

  // get the list of IDs that have a transfer attached -- we don't care about order, source/destination, etc.,
  // just that the ID is involved in a transfer
  const idsWithTransfers: string[] = uniq(
    getNodes(hypotheticalTransfers).reduce<string[]>((acc, t) => {
      return acc.concat(
        compact([
          t.sourceEntity?.id,
          t.sourceIndividual?.id,
          t.destinationEntity?.id,
          t.destinationIndividual?.id,
          t.destinationOrganization?.id,
        ])
      );
    }, [])
  );

  // Build the estate tax nodes inclusive of state estate taxes
  const estateTaxSummaries = [firstDeathTaxSummary, secondDeathTaxSummary];
  estateTaxSummaries.forEach((summary, idx) => {
    const afterDeath = idx === 0 ? AfterDeath.First : AfterDeath.Second;
    const id = getTaxNodeId({ afterDeath });
    const estateTaxSumForDeath = getCombinedTaxStateAndFederal(summary);

    if (
      estateTaxSumForDeath?.greaterThan(0) &&
      !visualizationConfig?.hideTaxTiles
    ) {
      const lineOne = 'Estate tax';
      let lineTwo = 'Federal';
      const states = summary?.stateTax?.map(({ stateCode }) => stateCode);

      if (states?.length) {
        lineTwo = `Combined Federal & ${states.join(', ')}`;
      }

      graph.addNodeSafe(id, {
        data: { afterDeath, id: ESTATE_TAX_NODE_IDENTIFIER },
        node: {
          id,
          hidden: visualizationConfig?.hideTaxTiles,
          type: 'tile',
          position: { x: 0, y: 0 },
          data: {
            lineOne,
            lineTwo,
            lineThree: visualizationConfig?.hideEntityValues
              ? undefined
              : formatCurrency(estateTaxSumForDeath, {
                  notation: 'compact',
                }),
            variant: TileVariant.Destructive,
            sectionLabelId: afterDeath,
            isNewTile: getIsNewNodeById(id),
          },
        },
        categorizationType: GraphNodeCategorizationType.EstateTax,
      });
    }
  });

  // Build the gift tax nodes for hypothetical waterfalls
  const giftTaxSummary = beforeFirstDeathTaxSummary;
  if (isHypotheticalWaterfall(waterfall)) {
    const afterDeath = AfterDeath.None;
    const giftTaxSumForDeath = getCombinedTaxStateAndFederal(giftTaxSummary);
    if (
      giftTaxSumForDeath?.greaterThan(0) &&
      !visualizationConfig?.hideTaxTiles
    ) {
      const id = getGiftTaxNodeId({ afterDeath });
      graph.addNodeSafe(id, {
        data: { afterDeath, id },
        node: {
          id,
          hidden: visualizationConfig?.hideTaxTiles,
          type: 'tile',
          position: { x: 0, y: 0 },
          data: {
            lineOne: 'Gift tax',
            lineTwo: '',
            lineThree: visualizationConfig?.hideEntityValues
              ? undefined
              : formatCurrency(giftTaxSumForDeath, {
                  notation: 'compact',
                }),
            variant: TileVariant.Destructive,
            sectionLabelId: afterDeath,
            isNewTile: getIsNewNodeById(id),
          },
        },
        categorizationType: GraphNodeCategorizationType.GiftTax,
      });
    }
  }

  // Add the nodes
  vizNodes.forEach((node) => {
    const nodeId = getNodeId(node);

    const tile = buildTileFromNode({
      visualizationConfig,
      nodeFragment: node,
      firstDeathGrantor,
      secondDeathGrantor,
      isNewTile: getIsNewNodeById(nodeId),
      hasHypotheticalTransfer: idsWithTransfers.includes(node.id),
      hasIncomingPouroverDisposition:
        idsWithIncomingPouroverDispositions.includes(node.id),
      hasOutgoingPouroverDisposition:
        idsWithOutgoingPouroverDispositions.includes(node.id),
    });

    const categorizationType = getCategorizationType(node);
    if (!tile) return;
    if (!categorizationType) return;

    graph.addNodeSafe(nodeId, {
      data: node,
      node: tile,
      categorizationType,
    });

    // External transfer nodes live in the pre-death section, so transfer
    // edges are only drawn for that section's copy of the node.
    if (node.afterDeath === AfterDeath.None) {
      // Add the outgoing external transfer edge from the node that originated
      // the transfer. A node may originate several transfers; they share one
      // edge, so the label shows their total.
      const outgoingTransferSum = sumDecimalJS(
        outgoingTransferSourceNodes
          .filter(({ sourceId }) => sourceId === node.id)
          .map(({ transferAmount }) => transferAmount)
      );
      if (outgoingTransferSum.greaterThan(0)) {
        addExternalTransferEdge(
          graph,
          nodeId,
          outgoingTransferNodeId,
          outgoingTransferSum
        );
      }

      // Add the incoming external transfer edge to the node that receives the
      // transfer, labeled with the total of all transfers it receives.
      const incomingTransferSum = sumDecimalJS(
        incomingTransferSourceNodes
          .filter(({ targetId }) => targetId === node.id)
          .map(({ transferAmount }) => transferAmount)
      );
      if (incomingTransferSum.greaterThan(0)) {
        addExternalTransferEdge(
          graph,
          incomingTransferNodeId,
          nodeId,
          incomingTransferSum
        );
      }
    }

    // Add the gift edge from the node that originated the transfer
    if (
      node.transferredToGiftingTax?.toNumber() &&
      !visualizationConfig?.hideTaxTiles
    ) {
      const afterDeath = node.afterDeath;
      const giftTaxNodeId = getGiftTaxNodeId({ afterDeath });

      addTaxEdge({
        graph,
        source: nodeId,
        target: giftTaxNodeId,
        value: node.transferredToGiftingTax,
      });
    }

    // Add the edge to the estate tax node. We do this in the node iteration vs. the edge iteration,
    // because some edges are hidden when the value of a tile is depleted to 0, see: https://linear.app/luminary/issue/T1-279
    //
    // Only include the node if there is an actual value, 0 is invalid
    // https://withluminary.slack.com/archives/C05BPQS8CG4/p1692748772395659?thread_ts=1692748503.489729&cid=C05BPQS8CG4
    // TODO (LUM-2026): Remove this when Edu adds it to the backend
    if (
      node.transferredToEstateTax?.toNumber() &&
      !visualizationConfig?.hideTaxTiles
    ) {
      const taxTarget = getTaxNodeId({
        afterDeath: getEstateTaxAfterDeath(node.afterDeath),
      });

      addTaxEdge({
        graph,
        source: nodeId,
        target: taxTarget,
        value: node.transferredToEstateTax,
      });
    }
  });

  // Add the grouped nodes + tax transfer edge
  Object.entries(groupedNodesByGroupNodeId).forEach(
    ([groupNodeId, groupNodes]) => {
      const group = assertNonNil(
        visualizationGroupsByGroupNodeId[groupNodeId],
        `No group found for ${groupNodeId}`
      );

      const nodeId = groupNodeId;
      const afterDeath = group.namespace as AfterDeath;
      const categorizationType: GraphNodeCategorizationType =
        GraphNodeCategorizationType.GroupNode;

      const shouldTriggerCreateGroupModal =
        group.displayName === SHOW_CREATE_GROUP_MODAL_SENTINEL;

      const lineOne = (() => {
        if (shouldTriggerCreateGroupModal) return '';
        return group.displayName ?? '';
      })();

      const groupedTiles = compact(
        groupNodes.map((node) => {
          const tile = buildTileFromNode({
            visualizationConfig,
            nodeFragment: node,
            firstDeathGrantor,
            secondDeathGrantor,
          });

          const {
            beforeActionsValue,
            afterActionsValue,
            shouldShowBeforeAndAfterPouroverWills,
            shouldShowBeforeAndAfterTransfers,
          } = getBeforeAndAfterData(node, {
            hasHypotheticalTransfer: idsWithTransfers.includes(node.id),
            hasIncomingPouroverDisposition:
              idsWithIncomingPouroverDispositions.includes(node.id),
            hasOutgoingPouroverDisposition:
              idsWithOutgoingPouroverDispositions.includes(node.id),
          });

          if (!tile) return null;

          const { id, data } = tile;

          return {
            tileId: id,
            nodeId: node.id,
            beforeActionsValue,
            afterActionsValue,
            shouldShowBeforeAndAfterPouroverWills,
            shouldShowBeforeAndAfterTransfers,
            ...data,
          };
        })
      );

      let beforeActionsTotalValue: Decimal | null = null;
      let afterActionsTotalValue: Decimal | null = null;

      // The group shows before/after totals when any member tile has a
      // before value and pour-over or transfer before/after data to show.
      const shouldShowBeforeAndAfter = groupedTiles.some(
        ({
          beforeActionsValue,
          shouldShowBeforeAndAfterPouroverWills,
          shouldShowBeforeAndAfterTransfers,
        }) =>
          !!beforeActionsValue &&
          (shouldShowBeforeAndAfterPouroverWills ||
            shouldShowBeforeAndAfterTransfers)
      );

      if (shouldShowBeforeAndAfter) {
        beforeActionsTotalValue = sumDecimalJS(
          compact(
            groupedTiles.map(({ beforeActionsValue }) => beforeActionsValue)
          )
        );
        afterActionsTotalValue =
          sumDecimalJS(
            compact(
              groupedTiles.map(({ afterActionsValue }) => afterActionsValue)
            )
          ) ?? new Decimal(0);
      }

      const shouldShowBeforeAndAfterTile =
        beforeActionsTotalValue && afterActionsTotalValue;

      const variant = getGroupNodeTileVariant(groupNodes);

      // Add the group node
      graph.addNodeSafe(nodeId, {
        data: { afterDeath, id: group.id },
        node: {
          id: nodeId,
          type: 'tile',
          position: { x: 0, y: 0 },
          data: {
            lineOne,
            lineTwo: group.description ?? '',
            // The summed value is replaced by the before/after display when present
            lineThree:
              visualizationConfig?.hideEntityValues || beforeActionsTotalValue
                ? undefined
                : formatCurrency(
                    sumDecimalJS(groupNodes.map((node) => node.value)),
                    {
                      notation: 'compact',
                      currencySign: 'accounting',
                    }
                  ),
            variant,
            badgeProps: {
              display: `${groupNodes.length} items`,
              variant:
                variant === TileVariant.Group
                  ? BadgeVariants.Gray
                  : BadgeVariants.Primary,
            },
            sectionLabelId: afterDeath,
            group,
            groupedTiles,
            children: (
              <>
                {shouldShowBeforeAndAfterTile && (
                  // beforeActionsTotalValue and afterActionsTotalValue are guaranteed to be non-null here
                  <TileBeforeAndAfter
                    variant={variant}
                    before={formatCurrency(beforeActionsTotalValue!, {
                      notation: 'compact',
                    })}
                    after={formatCurrency(afterActionsTotalValue!, {
                      notation: 'compact',
                    })}
                  />
                )}
                <GroupNodeItems
                  groupNodeId={nodeId}
                  nodes={groupNodes}
                  firstDeathGrantor={firstDeathGrantor}
                  secondDeathGrantor={secondDeathGrantor}
                  shouldTriggerCreateGroupModal={shouldTriggerCreateGroupModal}
                  namespace={afterDeath}
                />
              </>
            ),
          },
        },
        categorizationType,
      });

      // External transfer nodes live in the pre-death section; mirror the
      // per-node check so groups in later sections don't get edges back to them.
      if (afterDeath === AfterDeath.None) {
        const groupNodeIds = groupNodes.map(({ id }) => id);
        const incomingTransferAmountSum = sumDecimalJS(
          incomingTransferSourceNodes
            .filter(({ targetId }) => groupNodeIds.includes(targetId))
            .map(({ transferAmount }) => transferAmount)
        );
        const outgoingTransferAmountSum = sumDecimalJS(
          outgoingTransferSourceNodes
            .filter(({ sourceId }) => groupNodeIds.includes(sourceId))
            .map(({ transferAmount }) => transferAmount)
        );

        if (incomingTransferAmountSum.greaterThan(0)) {
          addExternalTransferEdge(
            graph,
            incomingTransferNodeId,
            nodeId,
            incomingTransferAmountSum
          );
        }

        if (outgoingTransferAmountSum.greaterThan(0)) {
          addExternalTransferEdge(
            graph,
            nodeId,
            outgoingTransferNodeId,
            outgoingTransferAmountSum
          );
        }
      }

      const transferTaxSum = sumDecimalJS(
        groupNodes.map((n) => n.transferredToEstateTax ?? new Decimal(0))
      );

      const giftTaxSum = sumDecimalJS(
        groupNodes.map((n) => n.transferredToGiftingTax ?? new Decimal(0))
      );

      // Add the transfer tax sum from nodes within the group to the tax tile
      if (transferTaxSum.toNumber() && !visualizationConfig?.hideTaxTiles) {
        const taxTarget = getTaxNodeId({
          afterDeath: getEstateTaxAfterDeath(afterDeath),
        });

        addTaxEdge({
          graph,
          source: nodeId,
          target: taxTarget,
          value: transferTaxSum,
        });
      }

      // Add edge from group to associated gift tax node
      if (giftTaxSum.toNumber() && !visualizationConfig?.hideTaxTiles) {
        const taxTarget = getGiftTaxNodeId({
          afterDeath,
        });

        addTaxEdge({
          graph,
          source: nodeId,
          target: taxTarget,
          value: giftTaxSum,
        });
      }
    }
  );

  // Draw the edges between the nodes
  drawEdges({
    graph,
    viz,
  });

  graph.setAttribute('firstPrimaryClientDeathId', firstGrantorDeath.id);
  graph.setAttribute('waterfall', waterfall);

  return graph;
}
