import { Divider, Grid, Typography } from '@mui/material';
import { useQuery } from '@tanstack/react-query';
import { getProcedures } from 'api/procedures';
import AddChart from 'components/FeaturesDashboard/AddChart';
import {
  allCategoricalKeys,
  ChartKey,
  ChartKeyType,
  convertToChartKeys,
  flattenCohort,
  isCategoricalChartKey,
  isDistanceBasedChartKey,
  isNumericalChartKey,
  slideToFeatures,
} from 'components/FeaturesDashboard/chart.util';
import ControlledChart, { ControlledChartOptions } from 'components/FeaturesDashboard/ControlledChart';
import { DEFAULT_PAGE_SIZE } from 'components/StudyDashboard/ProceduresPage/ProcedurePagination';
import { agentDefaultCharts } from 'constants/defaults';
import { ChartType } from 'interfaces/chart';
import { CohortWithSelectedFeatures } from 'interfaces/cohort_old';
import { Features } from 'interfaces/experimentResults';
import { Procedure } from 'interfaces/procedure';
import {
  concat,
  difference,
  filter,
  find,
  first,
  flatMap,
  includes,
  isEmpty,
  keys,
  last,
  map,
  partition,
  sample,
  slice,
  times,
  union,
  uniq,
} from 'lodash';

import React, { useState } from 'react';
import { BooleanParam, JsonParam, QueryParamConfig, useQueryParam, withDefault } from 'use-query-params';
import { isDistanceBasedFeatureKey } from 'utils/features';
import queryClient from 'utils/queryClient';
import { ExperimentResultsSelection, useEncodedFilters } from 'utils/useEncodedFilters';

/** Props accepted by the `ChartsSection` component. */
interface Props {
  /** Id of this section; the section's entry in the `chartsSections` URL param is looked up by this id. */
  id: number;
  /** Heading rendered at the top of the section. */
  title: string;
  /** Text rendered under the heading. */
  description: string;
  /** Ids of the charts (entries of the `reportCharts` URL param) displayed in this section. */
  chartIds?: number[];

  /** Cohort whose data is charted. */
  cohort: CohortWithSelectedFeatures;
  /** Every feature name known for the cohort, including features whose values are not loaded yet. */
  cohortAllFeatures: string[];
  /** Whether cohort data is still loading; every chart is rendered in its loading state while true. */
  isLoading: boolean;
  /** Functional setter used to merge newly fetched feature values into the cohort. */
  setCohort: (updater: (oldCohort: any) => any) => void;

  /** Chart types offered by the "add chart" control. */
  addChartActions: { key: ChartType; icon: any; name: string }[];
  /** Section density; controls grid spacing and chart tile widths. */
  layout: 'small' | 'large';
}

/** A group of charts shown together; entries of the `chartsSections` URL query param have this shape. */
export interface ChartSection {
  /** Section id; expected to match the key under which the section is stored in the sections record. */
  id: number;
  /** Section heading. */
  title: string;
  /** Section description text. */
  description: string;
  /** Ids of the charts that belong to this section. */
  chartIds?: number[];
  /** NOTE(review): not read in this file — presumably flags a PFS-specific section; confirm with callers. */
  isPFS?: boolean;
}

/**
 * A titled section of the charts dashboard.
 *
 * Chart configurations are read from / written to the `reportCharts` URL query param (keyed by chart id) and
 * section membership to the `chartsSections` URL query param (keyed by section id). Only charts whose ids are in
 * the `chartIds` prop are rendered here. When any configured chart references a feature the cohort has no values
 * for, those features are fetched and merged into the cohort through `setCohort`.
 */
const ChartsSection: React.FunctionComponent<React.PropsWithChildren<Props>> = ({
  cohort,
  cohortAllFeatures,
  isLoading,
  addChartActions,
  setCohort,
  layout,
  chartIds,
  title,
  description,
  id,
}) => {
  const pageSize = DEFAULT_PAGE_SIZE;

  const { generateEncodedParams } = useEncodedFilters({
    experimentResultsSelection: ExperimentResultsSelection.FeatureValues,
  });

  const [existingChartsMap, setExistingChartsMap] = useQueryParam<Record<number, ControlledChartOptions>>(
    'reportCharts',
    ChartConfigurationParam
  );

  // Only the setter is needed: section updates read the latest sections inside the functional updater.
  const [, setChartsSections] = useQueryParam<Record<number, ChartSection>>('chartsSections', ChartsSectionsParams);

  // Ids of charts added from this section; passed to ControlledChart as `expandEditSectionOnRender`.
  const [expandDefaultChartIds, setExpandDefaultChartIds] = useState<number[]>([]);

  /**
   * Encodes procedure-query params for the given page and feature selection.
   *
   * The selection is copied before sorting so the caller's array is not reordered in place; sorting makes the
   * encoded selection independent of input order.
   */
  const getEncodedParams = ({
    page,
    featureSelection,
    isSingleFeature = false,
  }: {
    page?: number;
    featureSelection?: string[];
    isSingleFeature?: boolean;
  }) =>
    generateEncodedParams(
      { page, pageSize },
      { isAnalysis: true, isSingleFeature, featuresSelection: [...(featureSelection ?? [])].sort() },
      { isAnalysis: BooleanParam, isSingleFeature: BooleanParam }
    );

  /**
   * Merges feature values from freshly fetched procedures into the current cohort.
   *
   * Slides are matched by procedure id and slide id. For a matched slide, the first experiment result's `features`
   * become the union of the existing and fetched features (fetched values win on conflict). New procedure objects
   * are built instead of mutating `cohort.procedures` in place. The merged procedures are pushed through
   * `setCohort`, and the merged cohort is returned.
   */
  const addNewFeatureDataToCohort = (proceduresWithNewFeature: Procedure[]) => {
    // Index fetched procedures once instead of scanning the whole response for every cohort procedure.
    const fetchedProceduresById = new Map(map(proceduresWithNewFeature, (p) => [p.id, p] as const));

    const newProcedures = map(cohort.procedures, (procedure) => {
      const correspondingProcedure = fetchedProceduresById.get(procedure.id);

      return {
        ...procedure,
        slides: map(procedure.slides, (slide) => {
          const correspondingSlide = find(correspondingProcedure?.slides, (s) => s.id === slide.id);
          if (!correspondingSlide) {
            return slide;
          }

          return {
            ...slide,
            experimentResults: [
              {
                ...first(slide.experimentResults),
                features: {
                  ...slideToFeatures(slide),
                  ...slideToFeatures(correspondingSlide),
                },
              },
            ],
          };
        }),
      };
    });

    setCohort((oldCohort) => ({
      ...oldCohort,
      procedures: newProcedures,
    }));
    return {
      ...cohort,
      procedures: newProcedures,
    };
  };

  const chartTypes: { key: ChartType; name: React.ReactNode }[] = map(addChartActions, ({ key, name }) => ({
    key,
    name,
  }));

  // Feature names with values somewhere in the cohort vs. the remaining known features that have none loaded.
  const allFeatureKeys: string[] = cohortAllFeatures || [];
  const cohortFeatures: Features[] = flattenCohort(cohort);
  const keysWithData: string[] = uniq(flatMap(cohortFeatures, (feature) => keys(feature)));
  const keysWithoutData: string[] = difference(allFeatureKeys, keysWithData);

  // Classify every known feature as distance-based, categorical (inferred config) or numerical (everything else).
  const [distanceBasedFeatures, otherFeatures] = partition(allFeatureKeys, isDistanceBasedFeatureKey);
  const categoricalFeatureKeys: string[] = map(cohort.inferredFeaturesConfig, 'featureName');
  const numericalFeatures: string[] = difference(otherFeatures, categoricalFeatureKeys);

  const numericalKeys: ChartKey[] = convertToChartKeys(numericalFeatures, ChartKeyType.Numerical);
  const categoricalKeys: ChartKey[] = convertToChartKeys(
    concat(allCategoricalKeys, categoricalFeatureKeys),
    ChartKeyType.Categorical
  );
  const distanceBasedKeys: ChartKey[] = convertToChartKeys(distanceBasedFeatures, ChartKeyType.DistanceBased);

  const availableKeys: ChartKey[] = union(categoricalKeys, numericalKeys, distanceBasedKeys);

  const existingCharts = map(existingChartsMap, (chartOptions, chartId) => ({ id: Number(chartId), chartOptions }));

  // Feature names used by some chart (horizontal, vertical or splitting key) that have no values in the cohort yet.
  // De-duplicated so each feature is requested once and the fetch query key does not depend on chart count.
  const keysInChartsWithoutData: string[] = uniq(
    flatMap(existingCharts, ({ chartOptions }) =>
      map(
        filter(
          [chartOptions.horizontalKey, chartOptions.verticalKey, chartOptions.splittingKey],
          (key) => key && includes(keysWithoutData, key.name)
        ),
        (key) => key.name
      )
    )
  );

  /** Whether any key of this chart still waits for its feature values to be fetched. */
  const doesChartHaveMissingData = (chartOptions: ControlledChartOptions) =>
    [chartOptions.horizontalKey, chartOptions.verticalKey, chartOptions.splittingKey].some(
      (key) => key != null && includes(keysInChartsWithoutData, key.name)
    );

  const encodedParamsForNewFeaturesFetching = getEncodedParams({
    featureSelection: keysInChartsWithoutData,
    isSingleFeature: true,
  });

  /**
   * Returns, for 0-based `page`, the encoded params of three procedure queries: (1) all loaded plus
   * `addedFeatures`, (2) only `addedFeatures`, and (3) no feature selection.
   */
  const getEncodedParamsOptionsWithFeatures = (addedFeatures: string[], page: number) => {
    const allFeaturesInCohort = uniq(concat(keysWithData, addedFeatures));
    const encodedParamsWithAllFeatures = getEncodedParams({
      featureSelection: allFeaturesInCohort,
      page: page + 1,
    });
    const encodedParamsWithAddedFeatures = getEncodedParams({
      featureSelection: addedFeatures,
      page: page + 1,
    });
    const basicEncodedParams = getEncodedParams({
      page: page + 1,
    });

    return [encodedParamsWithAllFeatures, encodedParamsWithAddedFeatures, basicEncodedParams];
  };

  useQuery({
    queryKey: ['procedures', encodedParamsForNewFeaturesFetching],
    queryFn: () => getProcedures(encodedParamsForNewFeaturesFetching),
    onSuccess: (response: { procedures: Procedure[] }) => {
      const updatedCohort = addNewFeatureDataToCohort(response.procedures);
      const procedureCount = updatedCohort?.procedures?.length ?? 0;
      const totalPages = procedureCount ? Math.ceil(procedureCount / pageSize) : 0;

      // Seed the cache of the paged procedure queries with the merged cohort data.
      times(totalPages, (page) => {
        const updatedCohortPage = {
          ...updatedCohort,
          // lodash `slice(array, start, end)` takes an END index, not a length.
          procedures: slice(updatedCohort.procedures, page * pageSize, (page + 1) * pageSize),
        };
        // Only the features requested above were merged, so cache under keys that claim exactly those.
        for (const encodedParams of getEncodedParamsOptionsWithFeatures(keysInChartsWithoutData, page)) {
          queryClient.setQueryData(['procedures', encodedParams], updatedCohortPage);
        }
      });
    },
    enabled: !isEmpty(keysInChartsWithoutData) && !isLoading,
  });

  /** Deletes a chart's configuration and removes its id from this section. */
  const removeChart = (chartId: number) => {
    const remainingCharts = { ...existingChartsMap };
    delete remainingCharts[chartId];
    setExistingChartsMap(remainingCharts);

    setChartsSections((currentChartSections) => {
      const section = currentChartSections?.[id];
      if (!section) {
        return currentChartSections;
      }
      return {
        ...currentChartSections,
        [id]: {
          ...section,
          // Compare as strings: section chart ids may be stored as strings (e.g. ids produced by lodash `keys()`).
          chartIds: filter(section.chartIds, (sectionChartId) => String(sectionChartId) !== String(chartId)),
        },
      };
    });
  };

  /**
   * Appends a new chart of the given type to this section, seeded with a random key that suits the type, and
   * marks it for `expandEditSectionOnRender`.
   */
  const onAddChart = (type: ChartType) => {
    // `existingCharts` follows object-key order, which is ascending for integer keys, so the last id is the highest.
    const lastChartId = last(existingCharts)?.id;
    const newChartId = lastChartId === undefined || Number.isNaN(lastChartId) ? 0 : lastChartId + 1;
    const firstProcedureId = first(cohort.procedures)?.id;

    const chartOptions: ControlledChartOptions = {
      type,
      countBy: 'Cases',
      ...(type === ChartType.Histogram && {
        horizontalKey: sample(categoricalKeys),
      }),
      ...(type === ChartType.Pie && {
        categoricalKey: sample(categoricalKeys),
      }),
      ...(type === ChartType.DistanceBased && {
        horizontalKey: sample(distanceBasedKeys),
        filteredCaseIds: firstProcedureId ? [firstProcedureId] : undefined,
      }),
    };

    // onChangeChartOptions would compare these options with themselves and convert nothing, so a single direct
    // write is equivalent and avoids a second URL update.
    setExistingChartsMap({ ...existingChartsMap, [newChartId]: chartOptions });
    setExpandDefaultChartIds((previousIds) => [...previousIds, newChartId]);
    setChartsSections((currentChartSections) => {
      const section = currentChartSections?.[id];
      if (!section) {
        return currentChartSections;
      }
      return {
        ...currentChartSections,
        [id]: {
          ...section,
          chartIds: [...(section.chartIds ?? []), newChartId],
        },
      };
    });
  };

  /**
   * Returns a copy of `chartOptions` with keys re-seeded for its (new) type:
   * - Pie reuses a categorical horizontal key, otherwise picks a random categorical key;
   * - Histogram reuses the categorical key, otherwise picks a random numerical key;
   * - DistanceBased picks a random distance-based key.
   */
  const getChartOptionsWhenChartTypeChange = (chartOptions: ControlledChartOptions): ControlledChartOptions => {
    const nextOptions: ControlledChartOptions = { ...chartOptions };

    if (nextOptions.type === ChartType.Pie) {
      nextOptions.categoricalKey = isCategoricalChartKey(nextOptions.horizontalKey)
        ? nextOptions.horizontalKey
        : sample(categoricalKeys);
    } else if (nextOptions.type === ChartType.Histogram) {
      nextOptions.horizontalKey = nextOptions.categoricalKey ? nextOptions.categoricalKey : sample(numericalKeys);
    } else if (nextOptions.type === ChartType.DistanceBased) {
      nextOptions.horizontalKey = sample(distanceBasedKeys);
    }

    return nextOptions;
  };

  /**
   * Stores edited options for a chart after applying automatic conversions:
   * - a type change re-seeds the keys for the new type;
   * - picking a numerical key in a pie chart turns it into a histogram;
   * - picking a distance-based key (categorical or vertical) in a non-distance-based chart turns it into a
   *   distance-based chart.
   * Works on a copy so the caller's options object is never mutated.
   */
  const onChangeChartOptions = (
    chartId: number,
    chartOptions: ControlledChartOptions,
    newChartMap: Record<number, ControlledChartOptions> = existingChartsMap
  ) => {
    const previousOptions = newChartMap[chartId];
    let nextOptions: ControlledChartOptions = { ...chartOptions };

    if (previousOptions?.type !== nextOptions.type) {
      nextOptions = getChartOptionsWhenChartTypeChange(nextOptions);
    }

    // Numerical key selected in a pie chart -> convert to a histogram.
    if (
      previousOptions?.type === ChartType.Pie &&
      previousOptions?.categoricalKey !== nextOptions.categoricalKey &&
      isNumericalChartKey(nextOptions.categoricalKey)
    ) {
      nextOptions.type = ChartType.Histogram;
      nextOptions.horizontalKey = nextOptions.categoricalKey;
      delete nextOptions.categoricalKey;
    }

    // Distance-based categorical key selected in a non-distance-based chart -> convert to a distance-based chart.
    if (
      previousOptions?.type !== ChartType.DistanceBased &&
      previousOptions?.categoricalKey !== nextOptions.categoricalKey &&
      isDistanceBasedChartKey(nextOptions.categoricalKey)
    ) {
      nextOptions.type = ChartType.DistanceBased;
      nextOptions.horizontalKey = nextOptions.categoricalKey;
      delete nextOptions.categoricalKey;
    }

    // Distance-based vertical key selected in a non-distance-based chart -> convert to a distance-based chart.
    if (
      previousOptions?.type !== ChartType.DistanceBased &&
      previousOptions?.verticalKey !== nextOptions.verticalKey &&
      isDistanceBasedChartKey(nextOptions.verticalKey)
    ) {
      nextOptions.type = ChartType.DistanceBased;
      nextOptions.horizontalKey = nextOptions.verticalKey;
      delete nextOptions.verticalKey;
    }

    setExistingChartsMap({ ...newChartMap, [chartId]: nextOptions });
  };

  const spacing = layout === 'small' ? 1 : 5;
  // Chart tiles are full width below `lg`, half width from `xl`; at `lg` only the large layout halves them.
  const chartLg = layout === 'small' ? 12 : 6;
  const sectionChartIds = map(chartIds, String);
  const sectionCharts = filter(existingCharts, (chart) => includes(sectionChartIds, String(chart.id)));

  return (
    <Grid item container spacing={spacing} py={1}>
      <Grid item xs={12}>
        <Typography variant="h3">{title}</Typography>
        <Typography variant="body1">{description}</Typography>
      </Grid>

      {map(sectionCharts, (chart) => (
        <Grid item xs={12} md={12} lg={chartLg} xl={6} key={chart.id} display="flex">
          <ControlledChart
            id={chart.id}
            availableKeys={availableKeys}
            loading={isLoading || doesChartHaveMissingData(chart.chartOptions)}
            cohorts={[cohort]}
            onRemove={() => removeChart(chart.id)}
            chartOptions={chart.chartOptions}
            onChangeChartOptions={(chartOptions) => onChangeChartOptions(chart.id, chartOptions)}
            chartTypes={chartTypes}
            layout={chart.chartOptions.type === ChartType.KaplanMeier ? 'large' : 'small'}
            expandEditSectionOnRender={includes(expandDefaultChartIds, chart.id)}
          />
        </Grid>
      ))}
      <Grid
        item
        xs={12}
        md={layout === 'small' ? 12 : 6}
        lg={layout === 'small' ? 12 : 6}
        xl={layout === 'small' ? 12 : 6}
        sx={{ minHeight: layout === 'large' ? 380 : 300 }}
        display="flex"
      >
        <AddChart cohorts={[]} onClick={onAddChart} actions={addChartActions} />
      </Grid>
      <Grid item xs={12}>
        <Divider />
      </Grid>
    </Grid>
  );
};

/**
 * Query-param config for `reportCharts`: chart options keyed by chart id, JSON-encoded in the URL, with
 * `agentDefaultCharts` used when the URL holds no value.
 */
export const ChartConfigurationParam: QueryParamConfig<Record<number, ControlledChartOptions>> = withDefault(
  JsonParam,
  agentDefaultCharts
);

/**
 * Default value of the `chartsSections` URL param: one "Cases Overview" section (id 1) containing every chart of
 * `agentDefaultCharts`.
 *
 * `id` is included so code reading `section.id` gets the record key back instead of `undefined`. `chartIds` are
 * converted to numbers because lodash `keys()` yields strings, while `ChartSection.chartIds` is `number[]` and
 * numeric chart ids never match string entries in lodash `difference`.
 */
export const defaultChartsSections = {
  1: {
    id: 1,
    title: 'Cases Overview',
    chartIds: map(keys(agentDefaultCharts), Number),
  },
};

/** Chart sections keyed by section id. */
type ChartSectionsById = Record<number, ChartSection>;

/**
 * Query-param config for `chartsSections`: JSON-encoded in the URL, with `defaultChartsSections` used when the URL
 * holds no value.
 */
export const ChartsSectionsParams: QueryParamConfig<ChartSectionsById> = withDefault(JsonParam, defaultChartsSections);

export default ChartsSection;
