import { difference, groupBy, intersection, uniq } from 'lodash';
import { useEffect, useMemo, useRef, useState } from 'react';
import { usePreviousDistinct } from 'react-use';

import { Row } from '../EntityMultiSelectTable.types';

/**
 * Collects the ids of the child rows of a group, i.e. the rows whose `path`
 * has more than one segment. Returns an empty list when the group is unknown.
 */
function getChildRowIds(groupRows: Row[] | undefined): string[] {
  return (groupRows ?? [])
    .filter((row) => row.path.length > 1)
    .map((row) => row.id);
}

/**
 * Selection state for a grouped multi-select table.
 *
 * Rows are grouped by the first segment of their `path`; that segment is used
 * as the group id and is looked up in the same id list as ordinary row ids.
 * Rows whose `path` has more than one segment are treated as the children of
 * their group. An effect keeps groups and children consistent:
 *
 * - when a group id newly appears in the selection, all of its children are
 *   added;
 * - when a group id disappears from the selection, all of its children are
 *   removed, unless the group was dropped by this hook itself because one of
 *   its children was deselected;
 * - otherwise a group is added once all of its children are selected and
 *   removed as soon as one of them is not.
 *
 * @param rows Rows shown in the table.
 * @param defaultSelectedRowIds Initial selection; defaults to every row id.
 *   The array is never mutated by this hook.
 * @returns `selectedRowIds` (always sorted with the default `Array#sort`
 *   ordering) and the raw state setter `setSelectedRowIds`.
 */
export function useEntityMultiSelectState(
  rows: Row[],
  defaultSelectedRowIds?: string[]
) {
  const [selectedRowIds, setSelectedRowIds] = useState<string[]>(
    defaultSelectedRowIds ?? rows.map((row) => row.id)
  );
  const prevSelectedRowIds = usePreviousDistinct(selectedRowIds);
  // Groups that the last sync pass deselected because a child was missing.
  const implicitlyDeselectedGroupIdsRef = useRef<string[]>([]);

  const rowsByGroupId: Record<string, Row[]> = useMemo(
    () => groupBy(rows, (row) => row.path[0]),
    [rows]
  );

  // Sort a copy: `Array#sort` works in place, and the state array may be the
  // caller's own `defaultSelectedRowIds` or an array passed to the setter.
  const sortedSelectedRowIds = useMemo(
    () => [...selectedRowIds].sort(),
    [selectedRowIds]
  );
  // Order-insensitive fingerprint of the selection, used as the effect
  // dependency so a new array with the same members does not re-run the effect.
  const selectionKey = JSON.stringify(sortedSelectedRowIds);

  // Effect to handle group toggles
  useEffect(() => {
    const groupIds = Object.keys(rowsByGroupId);
    const prevSelectedGroupIds = intersection(groupIds, prevSelectedRowIds);
    const currentSelectedGroupIds = intersection(
      groupIds,
      sortedSelectedRowIds
    );
    const newGroupIds = difference(
      currentSelectedGroupIds,
      prevSelectedGroupIds
    );
    const removedGroupIds = difference(
      prevSelectedGroupIds,
      currentSelectedGroupIds
    );

    if (newGroupIds.length > 0) {
      // A group was just selected: select all of its children as well.
      const nextIds = [...sortedSelectedRowIds];

      for (const groupId of newGroupIds) {
        nextIds.push(...getChildRowIds(rowsByGroupId[groupId]));
      }

      setSelectedRowIds(uniq(nextIds).sort());
      return;
    }

    if (removedGroupIds.length > 0) {
      // A group was just deselected: drop its children too, except when the
      // group only disappeared because the sync pass below removed it after
      // one of its children was deselected.
      const idsToRemove = new Set<string>();

      for (const groupId of removedGroupIds) {
        if (implicitlyDeselectedGroupIdsRef.current.includes(groupId)) {
          continue;
        }

        for (const childId of getChildRowIds(rowsByGroupId[groupId])) {
          idsToRemove.add(childId);
        }
      }

      setSelectedRowIds(
        uniq(sortedSelectedRowIds.filter((id) => !idsToRemove.has(id))).sort()
      );
      return;
    }

    // No group was toggled directly: derive each group's state from its
    // children.
    const selectedIdSet = new Set(sortedSelectedRowIds);
    // Groups that become selected because all of their children are selected
    const implicitlySelectedGroupIds: string[] = [];
    // Groups that become deselected because at least one child is not selected
    const implicitlyDeselectedGroupIds: string[] = [];

    for (const groupId of groupIds) {
      const childIds = getChildRowIds(rowsByGroupId[groupId]);

      if (childIds.length === 0) {
        // Nothing to derive from: `every` on an empty list is vacuously true
        // and would re-select a childless group the user has deselected.
        continue;
      }

      if (childIds.every((id) => selectedIdSet.has(id))) {
        implicitlySelectedGroupIds.push(groupId);
      } else {
        implicitlyDeselectedGroupIds.push(groupId);
      }
    }

    implicitlyDeselectedGroupIdsRef.current = implicitlyDeselectedGroupIds;

    const deselectedGroupIdSet = new Set(implicitlyDeselectedGroupIds);

    setSelectedRowIds(
      uniq([...sortedSelectedRowIds, ...implicitlySelectedGroupIds])
        .filter((id) => !deselectedGroupIdSet.has(id))
        .sort()
    );
    // `selectionKey` stands in for the selection itself; the previous
    // selection is deliberately read from the render that triggered the effect.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [rowsByGroupId, selectionKey]);

  return {
    selectedRowIds: sortedSelectedRowIds,
    setSelectedRowIds,
  };
}
