import CheckIcon from '@mui/icons-material/Check';
import ClearIcon from '@mui/icons-material/Clear';
import { Grid, Radio, Typography } from '@mui/material';
import {
  DataGrid,
  getGridSingleSelectOperators,
  getGridStringOperators,
  GridColDef,
  GridFilterModel,
  GridPaginationModel,
  GridSortModel,
} from '@mui/x-data-grid';
import { useQuery } from '@tanstack/react-query';
import { filter, includes, isEmpty, map, slice, sortBy, uniqBy } from 'lodash';
import moment from 'moment';
import React, { useMemo, useState } from 'react';

import { getModels } from 'api/platform';
import { getAllStudies } from 'api/study';
import { modelTypesByApiModelValue } from 'components/Pages/Jobs/inferenceFieldsOptions';
import { Model, trainingTypeOptions } from 'interfaces/model';
import { useAllCancerTypes } from 'utils/queryHooks/uiConstantsHooks';
import { useStainTypeIdToDisplayName } from 'utils/useStainTypeIdToDisplayName';

/** Page size used when the caller does not pass `defaultPageSize`. */
const defaultRowsPerPage = 10;

/** Props accepted by the `ModelsTable` component. */
export interface ModelsTableProps {
  // --- required ---

  /** Model type passed to `getModels` with every request. */
  modelType: string;
  /** URL of the currently selected model; the row whose `url` matches is rendered as selected. */
  selectedModelUrl: string;
  /** Invoked with the clicked row's model. */
  onSelect: (model: Model) => void;

  // --- optional scoping / behaviour ---

  /** Stain type to scope the request to; only sent when set. */
  stainType?: string;
  /** Study to scope the request to; only sent when set. */
  studyId?: string;
  /** Training type to scope the request to; only sent when set. */
  trainingType?: string;
  /** When `false`, the table's queries are disabled. Defaults to `true`. */
  enabled?: boolean;
  /** Initial page size and the single page-size option offered. Defaults to `defaultRowsPerPage`. */
  defaultPageSize?: number;
}

// Column filter operators never change, so build them once at module load instead of per render.
const equalsStringOperators = filter(getGridStringOperators(), { value: 'equals' });
const isSingleSelectOperators = filter(getGridSingleSelectOperators(), { value: 'is' });

/**
 * Table of models of a given `modelType`, with server-side pagination, sorting and filtering.
 *
 * The grid's pagination/sort/filter state is kept locally and translated into the params passed
 * to `getModels`. Clicking a row calls `onSelect` with that row; the first row whose `url`
 * equals `selectedModelUrl` is rendered as selected.
 */
const ModelsTable: React.FC<React.PropsWithChildren<ModelsTableProps>> = ({
  modelType,
  stainType = null,
  trainingType = null,
  studyId,
  selectedModelUrl,
  onSelect,
  enabled = true,
  defaultPageSize = defaultRowsPerPage,
}) => {
  const [filterModel, setFilterModel] = useState<GridFilterModel>({
    items: [],
  });
  const [sortModel, setSortModel] = useState<GridSortModel>([
    {
      field: 'date',
      sort: 'desc',
    },
  ]);

  const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({
    page: 0,
    pageSize: defaultPageSize,
  });

  // A new filter or sort produces a different result set, so the current page index may no
  // longer exist (or point at unrelated rows). Go back to the first page.
  const resetToFirstPage = () => {
    setPaginationModel((previous) => (previous.page === 0 ? previous : { ...previous, page: 0 }));
  };

  const handleFilterModelChange = (model: GridFilterModel) => {
    setFilterModel(model);
    resetToFirstPage();
  };

  const handleSortModelChange = (model: GridSortModel) => {
    setSortModel(model);
    resetToFirstPage();
  };

  /** Request param for the first column filter, or `{}` when that filter has no usable value. */
  const getFilterParam = (): Record<string, unknown> => {
    const [activeFilter] = filterModel.items;
    const value = activeFilter?.value;
    // The grid reports filter items without a value while the user is still editing or after the
    // input is cleared. Sending `{ [field]: undefined }` would also override a prop-driven param
    // with the same key (e.g. `studyId`, `trainingType`), so such items are skipped.
    if (!activeFilter?.field || value === undefined || value === null || value === '') {
      return {};
    }
    return { [activeFilter.field]: value };
  };

  /** Request param for the sort direction: `1` for ascending, `-1` otherwise. */
  const getSortParam = () => ({ sort: sortModel[0]?.sort === 'asc' ? 1 : -1 });

  /** Builds the full param object for `getModels` from props and grid state. */
  const getModelsParams = () => {
    const modelParams: {
      modelType: string;
      [key: string]: any;
    } = {
      modelType,
      ...paginationModel,
      // Optional scoping props are only sent when set.
      ...(stainType ? { stainType } : {}),
      ...(trainingType ? { trainingType } : {}),
      ...(studyId ? { studyId } : {}),
      // Spread after the scoping props, so a column filter on the same key takes precedence.
      ...getFilterParam(),
      ...(isEmpty(sortModel) ? {} : getSortParam()),
    };
    return modelParams;
  };

  const currentModelParams = getModelsParams();

  const {
    data: models,
    isLoading,
    isFetching,
  } = useQuery({
    queryKey: ['models', currentModelParams],
    queryFn: () => getModels(currentModelParams),
    enabled,
    keepPreviousData: true,
  });

  // The url/id/modelId are supposed to be unique, but in case of duplicates we keep the first one.
  const modelItems = useMemo(() => uniqBy(models?.items, 'url'), [models?.items]);

  // Ids of the selected row (at most one): the first model whose url matches `selectedModelUrl`.
  const selectionModelIds = useMemo(
    () =>
      map(
        slice(
          filter(modelItems, (item) => item.url === selectedModelUrl),
          0,
          1
        ),
        (item: Model) => item.id
      ),
    [modelItems, selectedModelUrl]
  );

  // NOTE(review): with paginationMode="server", `rowCount` should be the server's total number of
  // matching models, not the length of the current page; with the page length the grid likely
  // never offers a next page. Use the total from the `getModels` response if it exposes one — TODO confirm.
  const rowCount = modelItems?.length;

  const { stainTypeIdToDisplayName } = useStainTypeIdToDisplayName();

  const { data: allStudies } = useQuery({ queryKey: ['allStudies'], queryFn: getAllStudies, enabled });
  const { allCancerTypes } = useAllCancerTypes({ enabled });

  // MUI recommends keeping the `columns` reference stable; rebuild only when its inputs change.
  const columns = useMemo<GridColDef[]>(
    () => [
      {
        field: 'radiobutton',
        headerName: 'Selected',
        width: 100,
        sortable: false,
        renderCell: (params) => <Radio checked={includes(selectionModelIds, params.id)} value={params.id} />,
      },
      {
        field: 'date',
        headerName: 'Date',
        width: 150,
        filterable: false,
        // moment(undefined) means "now", so render nothing for a missing date instead.
        valueFormatter: (value) => (value ? moment(value as string).format('lll') : ''),
      },
      {
        field: 'modelId',
        headerName: 'Model Id',
        width: 80,
        filterOperators: equalsStringOperators,
        sortable: false,
        valueGetter: (_value, row) => row.meta.modelId,
      },
      {
        field: 'description',
        headerName: 'Description',
        width: 200,
        filterable: false,
        sortable: false,
        valueGetter: (_value, row) => row.meta.description,
      },
      {
        field: 'studyId',
        headerName: 'Study Id',
        width: 150,
        filterOperators: isSingleSelectOperators,
        type: 'singleSelect',
        valueOptions: sortBy(
          map(allStudies, (study) => ({ value: study.id, label: study.name })),
          'label'
        ),
        sortable: false,
        valueGetter: (_value, row) => row.meta.studyId,
      },
      {
        field: 'projectName',
        headerName: 'Project Name',
        width: 100,
        filterable: false,
        sortable: false,
        valueGetter: (_value, row) => row.meta.projectName,
      },
      {
        field: 'stainTypes',
        headerName: 'Stain Types',
        width: 80,
        filterable: false,
        sortable: false,
        valueGetter: (_value, row) => map(row.meta.stainTypes, (stainTypeId) => stainTypeIdToDisplayName(stainTypeId)),
      },
      {
        field: 'cancerTypes',
        headerName: 'Cancer Types',
        width: 90,
        filterOperators: isSingleSelectOperators,
        type: 'singleSelect',
        // Options are plain display-name strings (they have no `label`), so sort by the value itself.
        valueOptions: sortBy(map(allCancerTypes, (cancerType) => cancerType.displayName)),
        sortable: false,
        valueGetter: (_value, row) => row.meta.tissueTypes,
      },
      {
        field: 'end2end',
        headerName: 'End2End',
        width: 70,
        filterable: false,
        sortable: false,
        display: 'flex',
        valueGetter: (_value, row) => row.meta.end2end,
        renderCell: (params) => (
          <Typography margin="auto" variant="body2">
            {params?.row?.meta?.end2end ? <CheckIcon /> : <ClearIcon />}
          </Typography>
        ),
      },
      {
        field: 'trainingType',
        headerName: 'Training Type',
        width: 110,
        filterOperators: isSingleSelectOperators,
        type: 'singleSelect',
        valueOptions: map(trainingTypeOptions, (trainingTypeOption) => ({
          value: trainingTypeOption.value,
          label: trainingTypeOption.text,
        })),
        sortable: false,
        valueGetter: (_value, row) => row.meta.trainingType,
      },
      {
        field: 'modelType',
        headerName: 'Model Type',
        width: 90,
        filterable: false,
        sortable: false,
        valueGetter: (_value, row) => modelTypesByApiModelValue[row.meta.modelType]?.text ?? row.meta.modelType,
      },
    ],
    [selectionModelIds, allStudies, allCancerTypes, stainTypeIdToDisplayName]
  );

  return (
    <Grid container direction="column">
      <Grid item container spacing={1} pl={1} pt={1} height="auto" style={{ display: 'flex', flexDirection: 'column' }}>
        <DataGrid
          rowSelectionModel={selectionModelIds}
          loading={(isLoading || isFetching) && enabled}
          pagination
          rows={modelItems}
          rowCount={rowCount}
          columns={columns}
          onRowClick={(params) => onSelect(params.row)}
          paginationMode="server"
          sortingMode="server"
          filterMode="server"
          sortModel={sortModel}
          onSortModelChange={handleSortModelChange}
          onFilterModelChange={handleFilterModelChange}
          paginationModel={paginationModel}
          onPaginationModelChange={setPaginationModel}
          pageSizeOptions={[defaultPageSize]}
        />
      </Grid>
    </Grid>
  );
};

export default ModelsTable;
