import {
  OperationVariables,
  QueryHookOptions,
  QueryResult,
} from '@apollo/client';
import { GridPaginationModel } from '@mui/x-data-grid-pro';
import { isNil } from 'lodash';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';

import { PageInfoFragment } from '@/graphql/fragments/pageInfo.fragment.generated';
import { InputMaybe, Maybe, Scalars } from '@/types/schema';
import { getNodes } from '@/utils/graphqlUtils';

import { PageSizes } from '../constants';
import { DataTableProps } from '../types';

/**
 * Minimal result shape required by `usePaginatedDataTableQuery`: the query must alias its
 * Relay-style connection as `paginatedData`. The hook reads `totalCount` (table row count)
 * and `pageInfo.endCursor` (cursor for the following page), and hands the connection to
 * `getNodes` to flatten the edges.
 */
export interface PaginationQuery {
  __typename?: 'Query';
  paginatedData?: {
    totalCount: number;
    pageInfo: PageInfoFragment;
    // Nodes are `unknown` at this level; `CompactQueryData` recovers the concrete node
    // type from the caller's specific query type.
    edges?: Maybe<Maybe<{ node?: unknown }>[]>;
  };
}

/** Cursor-window variables that `usePaginatedDataTableQuery` sets itself on every request. */
interface InternalPageInfoQueryVariables {
  // Number of edges to fetch, i.e. the current page size.
  first: Scalars['Int']['input'];
  // endCursor of the previous page; left undefined when requesting the first page.
  after?: InputMaybe<Scalars['Cursor']['input']>;
}

/** Variables of a paginated query: arbitrary operation variables plus the hook-owned `first`/`after`. */
export type PaginationQueryVariables = OperationVariables &
  InternalPageInfoQueryVariables;

/**
 * The part of a paginated query's variables that a consumer supplies. `first` and `after`
 * are removed because the hook drives them from its own pagination state.
 */
export type ExternalPaginationQueryVariables<
  T extends PaginationQueryVariables,
> = Omit<T, keyof InternalPageInfoQueryVariables>;

/**
 * A hook with the options/result signature of Apollo's `useQuery`, fixed to one query result
 * type and variables type (presumably a codegen'd `useXxxQuery` hook — confirm against callers).
 */
type QueryHook<
  Query extends PaginationQuery,
  Variables extends PaginationQueryVariables,
> = (
  options: QueryHookOptions<Query, Variables>
) => QueryResult<Query, Variables>;

/** Options `usePaginatedDataTableQuery` accepts in addition to Apollo's query options. */
interface ExtendedQueryOpts {
  // Initial rows per page; the hook falls back to `PageSizes.Ten` when omitted.
  pageSize?: number;
}

/** The DataTable props `usePaginatedDataTableQuery` produces to drive server-side pagination. */
export type PaginatedTableProps = Pick<
  DataTableProps,
  | 'rowCount'
  | 'onPaginationModelChange'
  | 'pagination'
  | 'paginationMode'
  | 'pageSize'
  | 'paginationModel'
  | 'loading'
>;

// The type of a single `paginatedData.edges[number].node` in `Query`, with every nullable
// level along that path stripped. The hook pre-extracts these nodes so that components
// using it don't each need to flatMap over `edges` themselves.
export type CompactQueryData<
  Query extends PaginationQuery,
  Variables extends PaginationQueryVariables,
> = NonNullable<
  NonNullable<
    NonNullable<
      NonNullable<
        NonNullable<QueryResult<Query, Variables>['data']>['paginatedData']
      >['edges']
    >[number]
  >['node']
>;

/**
 * @description Hook to create props for a server paginated DataTable given an apollo query
 * that makes use of a Relay Pagination Connection. Consumers should namespace their entity Connection,
 * e.g. query GetActivityGroups { paginatedData: activityGroups() { ... } }
 *
 * Pagination is cursor based: the `endCursor` returned for page N is remembered and sent as
 * `after` when page N + 1 is requested. The table can therefore only move to page 0 or to a page
 * whose preceding page has already been loaded. Remembered cursors are discarded, and the table
 * returns to page 0, whenever the page size or the consumer's `variables` change, because the
 * cursors no longer describe the same result set.
 *
 * @param useQuery - Apollo-style query hook for a query whose connection is aliased `paginatedData`.
 * @param options - Apollo query options without `first`/`after`, plus an optional initial
 * `pageSize` (defaults to `PageSizes.Ten`). `fetchPolicy` defaults to `'cache-and-network'`.
 * @returns A tuple `[tablePaginationProps, queryResult]`. `queryResult.data` holds the
 * connection's nodes instead of the raw query data.
 */
export function usePaginatedDataTableQuery<
  Query extends PaginationQuery,
  Variables extends PaginationQueryVariables,
>(
  useQuery: QueryHook<Query, Variables>,
  {
    pageSize = PageSizes.Ten,
    ...queryOpts
  }: QueryHookOptions<
    Query,
    // Omit internal variables first and after, these are driven by the hook and should not be provided by the consumer
    ExternalPaginationQueryVariables<Variables>
  > &
    ExtendedQueryOpts
): [
  PaginatedTableProps,
  Omit<QueryResult<Query, Variables>, 'data'> & {
    data: CompactQueryData<Query, Variables>[];
  },
] {
  // pageToCursor[N] is the endCursor received for page N, i.e. the `after` for page N + 1.
  const pageToCursor = useRef<Record<number, string>>({});
  const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({
    page: 0,
    pageSize,
  });

  // Cursors are only valid for the consumer variables they were fetched with. When those
  // variables change, drop the cursors and go back to the first page. This is done during
  // render (React's "adjust state when a prop changes" pattern) rather than in an effect:
  // the cursor map is already empty when `after` is read below, so a stale cursor from the
  // previous result set is never sent to the server.
  const variablesKey = JSON.stringify(queryOpts.variables ?? null);
  const [cursorVariablesKey, setCursorVariablesKey] = useState(variablesKey);
  if (cursorVariablesKey !== variablesKey) {
    setCursorVariablesKey(variablesKey);
    pageToCursor.current = {};
    setPaginationModel((prev) =>
      prev.page === 0 ? prev : { ...prev, page: 0 }
    );
  }

  const variables: PaginationQueryVariables = {
    ...queryOpts.variables,
    // Spread after the consumer's variables so the hook always owns the cursor window,
    // even if a wider variables object happens to carry `first`/`after`.
    first: paginationModel.pageSize,
    after: pageToCursor.current[paginationModel.page - 1],
  };

  // Built on every render without useMemo: `queryOpts` is a fresh object from the rest
  // destructuring each render, so memoizing on it never reused a previous value.
  const queryArgs = {
    // Keep data fresh by default when we navigate back to the table from another page
    fetchPolicy: 'cache-and-network',
    ...queryOpts,
    variables,
  } as QueryHookOptions<Query, Variables>;

  const queryResult = useQuery(queryArgs);
  const { data, loading: loadingQuery } = queryResult;

  // Apollo's loading variable is true for cache-and-network, even when there is a cache hit.
  // We want to only show a loading spinner when there is no cache hit, AND a fetch is occurring.
  // Context: https://github.com/apollographql/apollo-client/issues/8669#issuecomment-1464016804
  const loading = loadingQuery && !data;
  const nextCursor = data?.paginatedData?.pageInfo.endCursor;
  const rowCount = data?.paginatedData?.totalCount;

  // In the delay between network requests, data?.paginatedData?.totalCount will be undefined.
  // This breaks MUI data grid which requires a numerical value, so we cache this value
  // between queries and default it to zero on initial load.
  const [cachedRowCount, setCachedRowCount] = useState(rowCount ?? 0);

  const onPaginationModelChange = useCallback(
    (model: GridPaginationModel) => {
      // Stored cursors mark page boundaries computed at the current page size; after a
      // page-size change they no longer line up with page indexes, so start over.
      if (model.pageSize !== paginationModel.pageSize) {
        pageToCursor.current = {};
        setPaginationModel({ page: 0, pageSize: model.pageSize });
        return;
      }
      // Allow pagination only when we have a cursor to paginate from, or when we're on the first page
      if (model.page === 0 || pageToCursor.current[model.page - 1]) {
        setPaginationModel(model);
      }
    },
    [paginationModel.pageSize]
  );

  useEffect(() => {
    setCachedRowCount((prevCount) => (isNil(rowCount) ? prevCount : rowCount));
  }, [rowCount]);

  useEffect(() => {
    // Remember where the current page ends so the next page can be requested after it.
    if (!loading && nextCursor) {
      pageToCursor.current[paginationModel.page] = nextCursor;
    }
  }, [loading, nextCursor, paginationModel.page]);

  const compactData = useMemo(() => {
    return getNodes(data?.paginatedData) as CompactQueryData<
      Query,
      Variables
    >[];
  }, [data]);

  const tablePaginationProps: PaginatedTableProps = {
    rowCount: cachedRowCount,
    onPaginationModelChange,
    pagination: true,
    paginationMode: 'server',
    pageSize,
    paginationModel,
    loading,
  };

  return [tablePaginationProps, { ...queryResult, data: compactData }];
}
