import { Grid } from '@mui/material';
import {
  DataGrid,
  getGridSingleSelectOperators,
  getGridStringOperators,
  GridColDef,
  GridPaginationModel,
  GridRowSelectionModel,
  GridSortModel,
} from '@mui/x-data-grid';
import { useQuery } from '@tanstack/react-query';
import { getModels } from 'api/platform';
import { getAllStudies } from 'api/study';
import { modelTypesByApiModelValue } from 'components/Pages/Jobs/inferenceFieldsOptions';
import { Model, trainingTypeOptions } from 'interfaces/model';
import { filter, first, includes, isEmpty, map, sortBy } from 'lodash';
import moment from 'moment';
import React, { useEffect, useState } from 'react';
import { useAllCancerTypes } from 'utils/queryHooks/uiConstantsHooks';

// Page size used for the models table's initial pagination state.
const defaultRowsPerPage = 10;

/** Props accepted by the `ModelsTable` component. */
export interface ModelsTableProps {
  /**
   * Model type to list. Sent to the models API as `modelType` and also compared against each
   * returned model's type (display text when a mapping exists, raw API value otherwise).
   */
  modelType: string;
  /**
   * Optional stain type. When truthy it is sent to the models API as `stainType`; it is also
   * matched against each model's `meta.stainTypes`. `null` is accepted because the component
   * itself defaults the prop to `null`.
   */
  stainType?: string | null;
  /** URL of the currently chosen model; the row whose `url` equals it is shown as selected. */
  modelUrlSelected: string;
  /** Called with the row's model when the user clicks a row. */
  onSelect: (model: Model) => void;
  /** Whether the data queries may run. The component defaults this to `true`. */
  enabled?: boolean;
}

/**
 * Table of models for a given model type (and optionally stain type).
 *
 * Pagination, sorting and filtering state is kept locally and forwarded to `getModels` as
 * request parameters; the returned page is additionally filtered on the client. Clicking a row
 * reports the model through `onSelect`, and the row whose `url` equals `modelUrlSelected` is kept
 * selected so that a URL typed elsewhere is reflected in the table.
 */
const ModelsTable: React.FC<React.PropsWithChildren<ModelsTableProps>> = ({
  modelType,
  stainType = null,
  modelUrlSelected,
  onSelect,
  enabled = true,
}) => {
  // Only `field` and `value` of the first filter item are forwarded to the API, so a minimal
  // structural type is enough; it also prevents `items` from being inferred as `never[]`.
  const [filterModel, setFilterModel] = useState<{ items: Array<{ field: string; value?: unknown }> }>({
    items: [],
  });
  const [sortModel, setSortModel] = useState<GridSortModel>([
    {
      field: 'date',
      sort: 'desc',
    },
  ]);

  const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({
    page: 0,
    pageSize: defaultRowsPerPage,
  });

  /** Request parameter for the first active filter item (`{ [field]: value }`), or `{}` when none is set. */
  const getFilterParam = (): Record<string, unknown> => {
    const [activeFilter] = filterModel.items;
    return activeFilter ? { [activeFilter.field]: activeFilter.value } : {};
  };

  /** Sort direction for the API: `1` for ascending, `-1` otherwise; `{}` when sorting is cleared. */
  const getSortParam = () => {
    const [activeSort] = sortModel;
    return activeSort ? { sort: activeSort.sort === 'asc' ? 1 : -1 } : {};
  };

  // Built once per render so the query key and the query function always see the same parameters.
  const modelsParams: {
    modelType: string;
    [key: string]: any;
  } = {
    modelType,
    ...paginationModel,
    ...(stainType ? { stainType } : {}),
    ...getFilterParam(),
    ...getSortParam(),
  };

  const { data: models, isLoading } = useQuery({
    queryKey: ['models', modelsParams],
    queryFn: () => getModels(modelsParams),
    enabled,
  });

  // Client-side pass over the returned page: a model is kept when its type (display text if a
  // mapping exists, raw API value otherwise) equals `modelType`, OR when its stain types include
  // `stainType`. lodash `filter` returns an array even when `models` is still undefined.
  // NOTE(review): the OR means a matching stain type alone is enough — confirm this is intended.
  const filteredModelItems = filter(models?.items, (model) => {
    const displayedType = modelTypesByApiModelValue[model.meta.modelType]?.text ?? model.meta.modelType;
    return displayedType === modelType || includes(model.meta.stainTypes, stainType);
  });

  const { data: allStudies } = useQuery({
    queryKey: ['allStudies'],
    queryFn: getAllStudies,
    enabled,
  });
  const { allCancerTypes } = useAllCancerTypes({ enabled });

  // The url/id/modelId are supposed to be unique, but in case of duplicates we take the first one.
  const selectionModelId = filteredModelItems.find((item) => item.url === modelUrlSelected)?.id;
  // Nullish check (not truthiness) so that an id of 0 would still count as a selection.
  const selectionModelIdAsArray = selectionModelId != null ? [selectionModelId] : [];
  const [selectionModel, setSelectionModel] = useState<GridRowSelectionModel>(selectionModelIdAsArray);

  // The user can type the url instead of selecting it from the table, so the selection has to
  // follow `modelUrlSelected`. The resolved id is a dependency as well: the url may be set before
  // the models have loaded, in which case the matching row only becomes known later.
  useEffect(() => {
    setSelectionModel(selectionModelId != null ? [selectionModelId] : []);
  }, [modelUrlSelected, selectionModelId]);

  // NOTE(review): with paginationMode="server" the grid expects the total number of rows here;
  // this is only the length of the current (client-filtered) page — confirm against the API response.
  const rowCount = filteredModelItems.length;

  // Restrict filtering to a single operator per column type, matching what is sent to the API.
  const equalsOperator = filter(getGridStringOperators(), { value: 'equals' });
  const isOperator = filter(getGridSingleSelectOperators(), { value: 'is' });

  const columns: GridColDef[] = [
    {
      field: 'date',
      headerName: 'Date',
      width: 150,
      filterable: false,
      valueFormatter: (value) => moment(value as string).format('lll'),
    },
    {
      field: 'modelId',
      headerName: 'Model Id',
      width: 80,
      filterOperators: equalsOperator,
      sortable: false,
      valueGetter: (_value, row) => row.meta.modelId,
    },
    {
      field: 'description',
      headerName: 'Description',
      width: 200,
      filterable: false,
      sortable: false,
      valueGetter: (_value, row) => row.meta.description,
    },
    {
      field: 'studyId',
      headerName: 'Study Id',
      width: 150,
      filterOperators: isOperator,
      type: 'singleSelect',
      valueOptions: sortBy(
        map(allStudies, (study) => ({ value: study.id, label: study.name })),
        'label'
      ),
      sortable: false,
      valueGetter: (_value, row) => row.meta.studyId,
    },
    {
      field: 'projectName',
      headerName: 'Project Name',
      width: 100,
      filterable: false,
      sortable: false,
      valueGetter: (_value, row) => row.meta.projectName,
    },
    {
      field: 'stainTypes',
      headerName: 'Stain Types',
      width: 80,
      filterable: false,
      sortable: false,
      valueGetter: (_value, row) => row.meta.stainTypes,
    },
    {
      field: 'cancerTypes',
      headerName: 'Cancer Types',
      width: 90,
      filterOperators: isOperator,
      type: 'singleSelect',
      // The options are plain strings, so sort them by identity; sorting by a 'label' key would
      // compare `undefined` values and leave the list unsorted.
      valueOptions: sortBy(map(allCancerTypes, (cancerType) => cancerType.displayName)),
      sortable: false,
      valueGetter: (_value, row) => row.meta.tissueTypes,
    },
    {
      field: 'end2end',
      headerName: 'End2End',
      width: 70,
      filterable: false,
      sortable: false,
      valueGetter: (_value, row) => row.meta.end2end,
    },
    {
      field: 'trainingType',
      headerName: 'Training Type',
      width: 110,
      filterOperators: isOperator,
      type: 'singleSelect',
      valueOptions: map(trainingTypeOptions, (trainingType) => ({
        value: trainingType.value,
        label: trainingType.text,
      })),
      sortable: false,
      valueGetter: (_value, row) => row.meta.trainingType,
    },
    {
      field: 'modelType',
      headerName: 'Model Type',
      width: 90,
      filterable: false,
      sortable: false,
      valueGetter: (_value, row) => modelTypesByApiModelValue[row.meta.modelType]?.text ?? row.meta.modelType,
    },
  ];

  return (
    <Grid container direction="column">
      <Grid item container spacing={1} pl={1} pt={1} height="78.5vh">
        <DataGrid
          onRowSelectionModelChange={setSelectionModel}
          rowSelectionModel={selectionModel}
          autoHeight
          loading={isLoading && enabled}
          pagination
          rows={filteredModelItems}
          rowCount={rowCount}
          columns={columns}
          onRowClick={(params) => onSelect(params.row)}
          paginationMode="server"
          sortingMode="server"
          filterMode="server"
          sortModel={sortModel}
          onSortModelChange={setSortModel}
          onFilterModelChange={setFilterModel}
          paginationModel={paginationModel}
          onPaginationModelChange={setPaginationModel}
          pageSizeOptions={[defaultRowsPerPage]}
        />
      </Grid>
    </Grid>
  );
};

export default ModelsTable;
