import { invariant } from '@epic-web/invariant';
import {
  logger,
  useBatchedPipe,
  useCurrenciesContext,
  wrapForErrorHandling,
  type WebsocketRequest,
} from '@talos/kyoko';
import { getIncludeCashFilter } from 'hooks/useIncludeCashFilter';
import { useSubAccountRollupMemberships, useSubAccounts } from 'providers';
import { useCallback, useMemo } from 'react';
import { map, pipe, type Observable } from 'rxjs';
import type { PortfolioRiskGridData } from '../types/PortfolioRiskGridData';
import { ROLLUP_GROUP_PREFIX } from '../types/types';
import { buildHierarchicalDataPath, convertToPortfolioRiskGridData } from '../types/utils';
import type { useRequestStateTagging } from '../useRequestStateTagging';
import type { ExtraStateType, FromRiskObsValue, ToObsValue } from './PortfolioManagementProvider.types';

/**
 * Tells whether a hierarchical data path is treated as Book-level by this module: its third-from-last segment
 * starts with `ROLLUP_GROUP_PREFIX`.
 *
 * @param hierarchicalDataPath - A grid data path, one string per hierarchy level.
 * @returns `true`/`false` from the prefix check, or `undefined` when the path has fewer than three segments.
 */
function isBookLevelData(hierarchicalDataPath: string[]) {
  const segmentIndex = hierarchicalDataPath.length - 3;
  // A negative index reads a non-existent property and yields undefined, mirroring `.at(-3)` on short paths.
  const candidateSegment: string | undefined =
    segmentIndex >= 0 ? hierarchicalDataPath[segmentIndex] : undefined;
  return candidateSegment?.startsWith(ROLLUP_GROUP_PREFIX);
}

/**
 * Validates that the rows of one batch carry consecutive `BatchIndex` values starting at 0, counted across all
 * messages in arrival order.
 *
 * @param messages - The messages that make up one batch. An empty batch is trivially valid.
 * @throws Error when a row's `BatchIndex` does not equal its position in the flattened batch.
 */
function validateMessages(messages: FromRiskObsValue[]) {
  // Positions are global to the batch, so flatten across message boundaries before checking.
  const rows = messages.flatMap(message => message.data);
  for (const [expectedIndex, row] of rows.entries()) {
    if (row.BatchIndex !== expectedIndex) {
      throw new Error(`BatchIndex is not in order, expected ${expectedIndex} but got ${row.BatchIndex}`);
    }
  }
}

/**
 * Batches a risk subscription and converts each completed batch into portfolio risk grid rows.
 *
 * Messages are accumulated by `useBatchedPipe` until one arrives with no rows, or whose last row has
 * `BatchHasMore === false`. The batch is then:
 * 1. validated with `validateMessages`;
 * 2. filtered by the include-cash settings found in the tag map entry's `extraState`;
 * 3. expanded into one grid row per hierarchical data path, with only one path per source row aggregating.
 *
 * Errors raised while processing a batch are passed to the error handler, which logs them.
 *
 * @param subscription - Stream of raw risk messages.
 * @param tagMapRef - Ref to the request tag map. The entry for a batch's tag must carry `extraState`.
 * @returns Whatever `useBatchedPipe` returns for the processed batches.
 */
export function useSubAccountRiskBatchPipe<TRequest extends Omit<WebsocketRequest, 'tag'>>(
  subscription: Observable<FromRiskObsValue>,
  tagMapRef: ReturnType<typeof useRequestStateTagging<TRequest, ExtraStateType>>['tagMapRef']
) {
  const { currenciesBySymbol } = useCurrenciesContext();
  const { subAccountsByName, subAccountsByID } = useSubAccounts();
  const { rollupMembershipsByChildParent } = useSubAccountRollupMemberships();

  // A batch is complete once a message arrives with no rows, or whose last row has BatchHasMore === false.
  const batchFilterCondition = useCallback((item: FromRiskObsValue) => {
    return item.data.length === 0 || item.data.at(-1)?.BatchHasMore === false;
  }, []);

  // Converts one completed batch of messages into grid rows.
  const batchProcessPipe = useMemo(() => {
    return pipe(
      wrapForErrorHandling<FromRiskObsValue[], ToObsValue>({
        wrappedPipe: map((messages): ToObsValue => {
          validateMessages(messages);

          // Guard the empty batch before reading the first message's tag. Otherwise the dereference of
          // messages[0] throws first and this early return can never be reached.
          const initialMessage = messages.at(0);
          if (!initialMessage) {
            return { data: [] };
          }

          const tag = initialMessage.tag;
          invariant(tag, `all messages should have a tag`);
          const tagMapEntry = tagMapRef.current.get(tag);

          invariant(tagMapEntry?.extraState, `extraState field required to process risk data`);
          const { showRollupHierarchy, selectedPortfolioId, includeCash, treatStablecoinsAsCash } =
            tagMapEntry.extraState;

          const isRowValidForIncludeCash = getIncludeCashFilter(
            includeCash,
            treatStablecoinsAsCash,
            currenciesBySymbol
          );

          const resultData: PortfolioRiskGridData[] = [];
          for (const message of messages) {
            for (const row of message.data) {
              if (!isRowValidForIncludeCash(row.Asset)) {
                continue;
              }
              const hierarchicalDataPaths = buildHierarchicalDataPath({
                contextSubAccountId: selectedPortfolioId ?? -1,
                node: row,
                showRollupHierarchy,
                rollupMembershipsByChildParent,
                subAccountsByName,
                subAccountsByID,
              });

              // In rollup hierarchy mode, paths for which isBookLevelData is truthy are not emitted.
              // NOTE(review): an earlier comment said only Book-level data should be *shown*, which is the
              // opposite of this filter; the filter direction is kept as-is, confirm the intent.
              // The filter runs BEFORE picking the aggregating path. Filtering afterwards could drop the
              // single aggregating path and leave the row counted zero times.
              const emittedDataPaths = showRollupHierarchy
                ? hierarchicalDataPaths.filter(dataPath => !isBookLevelData(dataPath))
                : hierarchicalDataPaths;

              // Sort for a deterministic order, then let only the first path aggregate so that each
              // source row is counted exactly once.
              const sortedDataPaths = emittedDataPaths
                .map(dataPath => ({ dataPath, dataPathJoined: dataPath.join('::') }))
                .sort((a, b) => a.dataPathJoined.localeCompare(b.dataPathJoined));

              sortedDataPaths.forEach(({ dataPath }, index) => {
                resultData.push(convertToPortfolioRiskGridData(row, dataPath, index > 0));
              });
            }
          }

          return { ...initialMessage, data: resultData };
        }),
        errorHandler: error => {
          logger.error(
            new Error(
              `Portfolio Management batch reader error (processing will continue): ${
                error instanceof Error ? error.message : error
              }`,
              {
                cause: error,
              }
            )
          );
        },
      })
    );
  }, [currenciesBySymbol, rollupMembershipsByChildParent, subAccountsByID, subAccountsByName, tagMapRef]);

  return useBatchedPipe({
    input: subscription,
    batchCondition: batchFilterCondition,
    batchHandlerPipe: batchProcessPipe,
  });
}
