import type { EntryOrMap, IndexKeys, MapOfEntryOrMap, NestedMap } from '@talos/kyoko';
import { nestedMapSum, nestedMapValues, useObservableValue } from '@talos/kyoko';
import type { ReactNode } from 'react';
import { useMemo } from 'react';
import { map } from 'rxjs';
import { useShowByBalances } from '../hooks/useShowByBalances';
import { useTreasuryManagementContext } from '../providers/TreasuryManagementStateAndTabsProvider';
import type { VisualizationVariant } from '../tokens';
import type { DrillKey, MergedBalance, MergedBalanceIndexKeys } from '../types';
import { getTreasuryManagementPercentageString } from '../utils';
import type { ChartDataPoint, ChartRing, ChartSlice, D3ChartColor } from './D3Chart/types';

/**
 * Balances split by chart direction, together with the parts of each slice that can be highlighted.
 *
 * NOTE(review): this interface is exported but not referenced anywhere else in this file, so the
 * precise meaning of each field should be confirmed against its consumers.
 */
export interface ChartData {
  /** One nested balances map per chart direction. */
  balances: {
    clockwise: NestedMap<MergedBalance>;
    counterClockwise: NestedMap<MergedBalance>;
  };
  /**
   * Per direction: a map from drill key to a (string -> number) map, plus the index key they relate to.
   * Presumably the inner map holds per-part sums for a slice; verify against the consumer.
   */
  highlighteableSliceParts: {
    clockwise: Map<DrillKey<MergedBalance>, Map<string, number>>;
    counterClockwise: Map<DrillKey<MergedBalance>, Map<string, number>>;
    indexKey: IndexKeys<MergedBalance>;
  };
}

/**
 * Intermediate record collected for each key of the map being rendered, before any slice is built.
 * `useChartData` gathers these first so the per-direction totals are known when the slices are created.
 */
interface SliceCreationData {
  /** The entry (leaf balance) or nested map stored under the key. */
  datapoints: EntryOrMap<MergedBalance>;
  /** Drill target and entry type information copied onto the resulting slice. */
  metadata: SliceCreationMetadata;
  /** Only populated when `datapoints` is a Map. */
  highlighteableParts?: SliceHighlighteableParts;
}

/** Metadata attached to a slice: where clicking it drills to, and which index key the slice represents. */
interface SliceCreationMetadata {
  /** Drill key path produced by `getDrillTo`; undefined when the value under the key is not a Map. */
  drillTo: string[] | undefined;
  /** Display names resolved for the keys in `drillTo` (see `getDrillToDisplayNames`). */
  drillToDisplayNames: string[] | undefined;
  /** Index key of the level the slice's key belongs to. */
  entryType: MergedBalanceIndexKeys;
}

/**
 * The highlightable parts of a single slice, exposed both as a keyed lookup and as an ordered list.
 * `getSliceHighlighteableParts` fills both with the very same data point objects.
 */
interface SliceHighlighteableParts {
  /** Lookup of data points by key. */
  map: Map<DrillKey<MergedBalance>, ChartDataPoint<MergedBalance>>;
  /** The same data points, in the iteration order of the map they were built from. */
  array: ChartDataPoint<MergedBalance>[];
}

/** Callbacks and settings `useChartData` needs in order to turn balances into chart slices. */
export interface UseChartDataParams {
  /** Extracts the numeric value of a balance; summed (plain and absolute) per slice in `createSlice`. */
  getValue: (item: MergedBalance) => number;
  /**
   * Resolves the color of a data point. Within this file it is called with ring `'inner'` for slices
   * and `'outer'` for highlightable parts.
   */
  getColor: (
    key: DrillKey<MergedBalance>,
    indexKey: IndexKeys<MergedBalance>,
    value: number,
    ring?: ChartRing
  ) => D3ChartColor | undefined;
  /** Optional icon lookup for a slice key. */
  getIcon?: (
    key: DrillKey<MergedBalance>
  ) => { renderIcon: (size?: number) => ReactNode; defaultIconSize: number } | undefined;
  /** Resolves a display name for a key at the given index level; used for slices, drill targets and the center label. */
  getDisplayName: (key: DrillKey<MergedBalance>, indexKey: IndexKeys<MergedBalance>) => string | undefined;
  /** Visualization variant; determines each slice's render value (see `getRenderValue`). */
  visualization: VisualizationVariant;
}

/**
 * This hook grabs balances from a balances provider and builds the data needed by the D3Chart.
 *
 * The balances at the current drill level are split by the sign of their net equivalent amount into a
 * "clockwise" (positive) and a "counter-clockwise" (negative) set; entries summing to exactly 0 are dropped.
 * Each remaining entry becomes a chart slice.
 *
 * @returns `undefined` while no balances value is available or when the current drill keys do not resolve
 *   to a map; otherwise an object with `metadata` (total net sum, center labels, whether we are drilled
 *   all the way), `clockwise` and `counterClockwise` (each holding its slices and its net sum).
 */
export const useChartData = ({ getValue, getColor, getIcon, getDisplayName, visualization }: UseChartDataParams) => {
  const { state } = useTreasuryManagementContext();
  const { showBy, flattenedDrillKeys } = state;
  const balancesNestedMapObs = useShowByBalances(showBy);

  // When a new value is emitted, we have to wrap the map in a new obj in order to trigger change detection each time
  const wrappedBalancesNestedMap = useObservableValue(
    () => balancesNestedMapObs.pipe(map(balancesNestedMap => ({ balancesNestedMap }))),
    [balancesNestedMapObs]
  );

  const renderData = useMemo(() => {
    // Collects everything needed to later create a slice for one key of the drilled-into map.
    // Note: this closure reads `indexKeys` and `highlighteablePartsIndexKey`, which are declared further down
    // in this callback. That works because the closure is only invoked after those declarations have run.
    const getSliceCreationData = (
      key: string,
      value: EntryOrMap<MergedBalance>,
      indexKey: MergedBalanceIndexKeys,
      prefixKey?: string
    ) => {
      // The direction (clockwise / counter-clockwise) is decided on this sum of net equivalent amounts
      const sum = nestedMapSum(value, b => b.netEquivalentAmount);
      // Only nested maps have parts to highlight; a leaf entry has nothing below it
      const collectHighlighteableParts = value instanceof Map;

      const drillTo = getDrillTo(key, value, prefixKey);
      const drillToDisplayNames = getDrillToDisplayNames(drillTo, indexKeys, flattenedDrillKeys, getDisplayName);

      const data: SliceCreationData = {
        datapoints: value,
        metadata: { drillTo, drillToDisplayNames, entryType: indexKey },
      };

      if (collectHighlighteableParts) {
        // NOTE(review): the sums are grouped by `highlighteablePartsIndexKey` (derived from `showBy`), while the
        // entry type handed to `getColor` is `nextIndexKey` (derived from `indexKey`). Whether these two can
        // ever differ depends on the order of `indexKeys`, which is not visible here — confirm.
        const sums = getKeySums(value, highlighteablePartsIndexKey);
        const nextIndexKey = indexKey === 'currency' ? 'marketAccountID' : 'currency';
        const highlighteableParts = getSliceHighlighteableParts(nextIndexKey, sums, getColor);
        data.highlighteableParts = highlighteableParts;
      }

      return { sum, data };
    };

    // No balances value has been received yet
    if (!wrappedBalancesNestedMap) {
      return undefined;
    }

    const { balancesNestedMap } = wrappedBalancesNestedMap;
    const indexKeys = balancesNestedMap.indexKeys;

    // If we're drilled at all, grab the map we're currently drilled into and use that going forward as the data source
    const mapDrilledInto = (
      flattenedDrillKeys.length > 0 ? balancesNestedMap.getMapOrEntry(flattenedDrillKeys) : balancesNestedMap.map
    ) as MapOfEntryOrMap<MergedBalance> | undefined;

    // The drill keys did not resolve to anything in the current balances
    if (mapDrilledInto == null) {
      return undefined;
    }

    const highlighteablePartsIndexKey: IndexKeys<MergedBalance> =
      showBy === 'counterparty' ? 'currency' : 'marketAccountID';

    // c = clockwise (positive sums), cc = counter-clockwise (negative sums)
    const cData: Map<string, SliceCreationData> = new Map();
    const ccData: Map<string, SliceCreationData> = new Map();

    // The index key of the level we are currently looking at is the one right after the drilled keys
    const mapDrilledIntoIndexKey = indexKeys[flattenedDrillKeys.length];
    // This for loop splits the balances into clockwise and counter-clockwise maps.
    for (const [key, value] of mapDrilledInto) {
      // Collect all the "slice creation data" for each data point in the map, and depending on the data's sum, place it into either the clockwise or c-clockwise storage map
      const { sum, data } = getSliceCreationData(key, value, mapDrilledIntoIndexKey);

      // Entries that net out to exactly zero are not rendered at all
      if (sum === 0) {
        continue;
      }

      sum > 0 ? cData.set(key, data) : ccData.set(key, data);
    }

    // sum all slice creation data within each "direction", clockwise and c-clockwise
    const cSum = sumMapOfSliceCreationDatapoints(cData);
    const ccSum = sumMapOfSliceCreationDatapoints(ccData);

    const clockwiseRenderData: ChartSlice<MergedBalance>[] = [];
    const counterClockwiseRenderData: ChartSlice<MergedBalance>[] = [];

    // Iterate over the two sets of slice creation data and actually create slices from them here.
    // Data collection and slice creation are two separate steps because creating a slice requires
    // knowledge of the sum of its entire "direction". So to create a clockwise slice, we need to know the entire clockwise sum.
    [cData, ccData].forEach((data, i) => {
      const isClockwise = i === 0;

      // Convert every slice creation data entry of this direction into a chart slice
      for (const [key, value] of data.entries()) {
        const { datapoints, metadata, highlighteableParts } = value;
        const slice = createSlice(
          key,
          datapoints,
          isClockwise ? cSum : ccSum,
          data.size,
          highlighteableParts,
          metadata,
          visualization,
          getValue,
          getColor,
          getDisplayName,
          getIcon
        );

        isClockwise ? clockwiseRenderData.push(slice) : counterClockwiseRenderData.push(slice);
      }
    });

    const clockwise = {
      data: clockwiseRenderData,
      netSum: cSum,
    };

    const counterClockwise = {
      data: counterClockwiseRenderData,
      netSum: ccSum,
    };

    // The center label shows the display name of the deepest drilled key (falling back to the JSON-stringified key),
    // or a fixed label when we are not drilled at all
    const lastDrillKey = flattenedDrillKeys.length > 0 ? flattenedDrillKeys[flattenedDrillKeys.length - 1] : undefined;
    const centerLabel = lastDrillKey
      ? [getDisplayName(lastDrillKey, indexKeys[flattenedDrillKeys.length - 1]) ?? JSON.stringify(lastDrillKey)]
      : ['Total Portfolio'];

    return {
      metadata: {
        totalNetSum: clockwise.netSum + counterClockwise.netSum,
        centerLabels: centerLabel,
        // True once the number of drill keys is one less than the number of index keys
        drilledAllTheWay: flattenedDrillKeys.length === indexKeys.length - 1,
      },
      clockwise,
      counterClockwise,
    };
  }, [
    wrappedBalancesNestedMap,
    flattenedDrillKeys,
    showBy,
    getValue,
    getColor,
    getDisplayName,
    getIcon,
    visualization,
  ]);

  return renderData;
};

/**
 * Builds a single chart slice for `key`.
 *
 * The slice consists of the chart data point for the key (see `mapToChartDataPoint`), extended with the
 * drill target information from `metadata`, the slice's percentage string relative to `allSum`, and the
 * already-computed highlightable parts.
 *
 * @param key - key of the slice within the map currently being rendered
 * @param value - the entry or nested map stored under `key`
 * @param allSum - sum of the whole direction (clockwise or counter-clockwise) the slice belongs to
 * @param nEntries - number of slices in that direction
 * @param highlighteableParts - highlightable parts to attach to the slice, if any
 * @param metadata - drill target and entry type of the slice
 * @param visualization - visualization variant used to derive the render value
 * @param getValue - value extractor applied to every leaf balance below `value`
 * @param getColor - color resolver
 * @param getDisplayName - display name resolver
 * @param getIcon - optional icon resolver
 */
function createSlice(
  key: string,
  value: EntryOrMap<MergedBalance>,
  allSum: number,
  nEntries: number,
  highlighteableParts: SliceHighlighteableParts | undefined,
  metadata: SliceCreationMetadata,
  visualization: VisualizationVariant,
  getValue: UseChartDataParams['getValue'],
  getColor: UseChartDataParams['getColor'],
  getDisplayName: UseChartDataParams['getDisplayName'],
  getIcon: UseChartDataParams['getIcon']
) {
  // Net value of everything below this key, and the same with each balance taken as an absolute value
  const netValue = nestedMapSum(value, getValue);
  const absoluteNetValue = nestedMapSum(value, balance => Math.abs(getValue(balance)));

  const dataPoint = mapToChartDataPoint(
    visualization,
    key,
    netValue,
    absoluteNetValue,
    nEntries,
    metadata.entryType,
    getColor,
    getDisplayName,
    getIcon
  );

  return {
    ...dataPoint,
    drillTo: metadata.drillTo,
    drillToDisplayNames: metadata.drillToDisplayNames,
    percentage: getTreasuryManagementPercentageString(allSum, netValue),
    highlighteableParts,
  };
}

// Walks every leaf balance of a nested map and totals `netEquivalentAmount` per distinct value of the given index key
const getKeySums = (map: MapOfEntryOrMap<MergedBalance>, key: MergedBalanceIndexKeys): Map<string, number> => {
  const totalsByKey = new Map<string, number>();

  for (const leaf of nestedMapValues(map)) {
    const group = leaf[key];
    // Start from 0 the first time a group is seen
    const runningTotal = totalsByKey.get(group) ?? 0;
    totalsByKey.set(group, runningTotal + leaf.netEquivalentAmount);
  }

  return totalsByKey;
};

// Computes the "drillTo" path for a ChartSlice, i.e. the list of keys to drill through when the slice is clicked.
// Leaf entries (non-Map values) cannot be drilled into and yield undefined.
// For maps, the nearest level holding more than one child is targeted: a map with several children resolves to
// just [key], whereas levels with a single child are passed through (e.g. a market containing only one market
// account results in a double-drill that skips the market account level).
// When a prefixKey is given it is prepended to the resolved path.
function getDrillTo(key: string, value: EntryOrMap<MergedBalance>, prefixKey?: string) {
  // getNearestViableDrillToKeyList itself returns [key] when the map has more than one child
  const path: DrillKey<MergedBalance>[] | undefined =
    value instanceof Map ? getNearestViableDrillToKeyList(value, key) : undefined;

  if (path == null || prefixKey == null) {
    return path;
  }

  const prefixedPath: DrillKey<MergedBalance>[] = [prefixKey, ...path];
  return prefixedPath;
}

/**
 * Assembles the chart data point for one slice key.
 *
 * The render value is derived from the visualization variant, the color is resolved for the `'inner'` ring,
 * and the display name and icon are looked up through the supplied callbacks.
 */
function mapToChartDataPoint(
  renderVariant: VisualizationVariant,
  key: DrillKey<MergedBalance>,
  value: number,
  absoluteValue: number,
  nEntries: number,
  entryType: IndexKeys<MergedBalance>,
  getColor: UseChartDataParams['getColor'],
  getDisplayName: UseChartDataParams['getDisplayName'],
  getIcon: UseChartDataParams['getIcon']
): ChartDataPoint<MergedBalance> {
  const renderValue = getRenderValue(renderVariant, value, nEntries);
  const color = getColor(key, entryType, value, 'inner');
  const displayName = getDisplayName?.(key, entryType);
  const icon = getIcon?.(key);

  return { key, value, renderValue, absoluteValue, color, displayName, icon };
}

// Turns a map of (key -> summed value) into the highlightable parts of a slice.
// Every entry becomes a data point colored for the 'outer' ring. The resulting data points are exposed both as an
// array (in the iteration order of the input map) and as a lookup map sharing the very same objects.
// Returns undefined when either the map or the entry type is missing.
function getSliceHighlighteableParts(
  entryType?: IndexKeys<MergedBalance>,
  map?: Map<DrillKey<MergedBalance>, number>,
  getColor?: UseChartDataParams['getColor']
): SliceHighlighteableParts | undefined {
  if (!(map && entryType)) {
    return undefined;
  }

  const parts = Array.from(
    map,
    ([partKey, partValue]): ChartDataPoint<MergedBalance> => ({
      key: partKey,
      value: partValue,
      renderValue: partValue,
      absoluteValue: Math.abs(partValue),
      color: getColor?.(partKey, entryType, partValue, 'outer'),
    })
  );

  const partsByKey = new Map<DrillKey<MergedBalance>, ChartDataPoint<MergedBalance>>();
  for (const part of parts) {
    partsByKey.set(part.key, part);
  }

  return { map: partsByKey, array: parts };
}

/**
 * Resolves the value used to size a slice when rendering.
 *
 * With the `'Equal'` variant every slice gets the same share, `1 / nEntries`. With `'Value'`, or any other
 * variant, the magnitude (absolute value) of `value` is used.
 */
function getRenderValue(variant: VisualizationVariant, value: number, nEntries: number): number {
  if (variant === 'Equal') {
    return 1 / nEntries;
  }

  return Math.abs(value);
}

// Finds the nearest viable drill key together with the trail of keys leading to it.
// Starting at the given entry/map, descends through maps that hold exactly one child until it reaches either:
//  - a map with more than one child: that is where we want to drill to, so the trail including its key is returned
//  - a leaf entry: the trail collected so far (without the leaf's own key) is returned
//  - an empty map: nothing to drill to, an empty list is returned
// The keyTrail argument is never mutated.
function getNearestViableDrillToKeyList<T>(
  entryOrMap: EntryOrMap<T>,
  entryOrMapKey: DrillKey<T>,
  keyTrail: DrillKey<T>[] = []
): DrillKey<T>[] {
  const trail = [...keyTrail];
  let node: EntryOrMap<T> = entryOrMap;
  let nodeKey: DrillKey<T> = entryOrMapKey;

  while (node instanceof Map) {
    trail.push(nodeKey);

    if (node.size > 1) {
      // This is where we want to drill to!
      return trail;
    }

    if (node.size !== 1) {
      return [];
    }

    // There is only 1 item here. Step into it and go again!
    [nodeKey, node] = [...node.entries()][0];
  }

  return trail;
}

/**
 * Resolves display names for the keys of a drill path.
 *
 * The key at position `i` of `drillTo` is looked up with the index key at `flattenedDrillKeys.length + i`,
 * i.e. relative to the depth we are already drilled to. The resolved names are passed through `compact()`
 * before being returned. Yields `undefined` when there is no drill path.
 */
function getDrillToDisplayNames(
  drillTo: string[] | undefined,
  indexKeys: MergedBalanceIndexKeys[],
  flattenedDrillKeys: string[],
  getDisplayName: UseChartDataParams['getDisplayName']
): string[] | undefined {
  if (drillTo == null) {
    return undefined;
  }

  const displayNames = drillTo.map((drillKey, position) => {
    const indexKey = indexKeys[flattenedDrillKeys.length + position];
    return getDisplayName?.(drillKey, indexKey);
  });

  return displayNames.compact();
}

/** Totals the net equivalent amount of all datapoints held by the given slice creation data entries. */
function sumMapOfSliceCreationDatapoints(map: Map<string, SliceCreationData>): number {
  let total = 0;

  for (const { datapoints } of map.values()) {
    total += nestedMapSum(datapoints, balance => balance.netEquivalentAmount);
  }

  return total;
}
