import { yupResolver } from '@hookform/resolvers/yup';
import { Checkbox, CircularProgress, FormControlLabel, Grid, TextField, Typography } from '@mui/material';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { createJobPreset, getJobPresets } from 'api/jobPreset';
import { getModel, runInference } from 'api/platform';
import { getStainTypeFilteredIds } from 'api/stainTypes';
import { AnnotationAssignmentAutocomplete } from 'components/atoms/AnnotationAssignmentAutocomplete';
import LabelledDropdown from 'components/atoms/Dropdown/LabelledDropdown';
import PresetSection from 'components/atoms/PresetSection';
import { getModelId } from 'components/Pages/CalculateFeatures/utils';
import { BasePreset } from 'interfaces/basePreset';
import { InferenceJob, JobType } from 'interfaces/job';
import { NormalizationConfig } from 'interfaces/jobs/multiplex/normalizationParams';
import { Model } from 'interfaces/model';
import { filter, find, first, includes, isEmpty, join, keyBy, map } from 'lodash';
import moment from 'moment';
import { useSnackbar } from 'notistack';
import React, { useEffect, useState } from 'react';
import { Controller, SubmitHandler, useForm } from 'react-hook-form';
import { humanize } from 'utils/helpers';
import { casesOptions, CasesParams, casesSchema } from 'utils/useCasesParams';
import { encodeQueryParamsUsingSchema, useEncodedFilters } from 'utils/useEncodedFilters';
import { useStainTypeIdToDisplayName } from 'utils/useStainTypeIdToDisplayName';
import * as yup from 'yup';
import {
  channelsToExtractOptions,
  classNamesOptions,
  clearMlMachineTypeOptions,
  dynamicCellDetectionConfigOptions,
  modelTypeCell,
  modelTypeDefect,
  modelTypeRunInferenceOptions,
  modelTypesByApiModelValue,
  modelTypeTsm,
  nonDynamicOption,
} from '../../../../Pages/Jobs/inferenceFieldsOptions';
import { JobWithRebuild } from '../JobWithRebuild';
import { defaultChannelNormalizationConfig, normalizationSchema } from '../Multiplex/RunMultiplexNormalization';
import NormalizationParamsForm from '../Multiplex/RunNormalization/NormalizationParamsForm';
import { OldJobsStep } from '../OldJobsStep';
import { PlatformStepper } from '../PlatformStepper';
import ModelsTable from './ModelsTable';

/** notistack key of the persistent "Waiting for Inference to start" snackbar; closed when the run mutation settles. */
const SNACK_BAR_KEY_RUN_INFERENCE = 'RUN_INFERENCE';
/** notistack key of the persistent "Saving preset..." snackbar; closed when the save-preset mutation settles. */
const SNACK_BAR_KEY_SAVE_PRESET = 'SAVE_PRESET';

/** Props for the inference-job wizard. */
export interface RunInferenceStepsProps {
  /** Existing job id; used as the initially selected old job and forwarded to `JobWithRebuild`. */
  jobId?: string;
  /** Case filter: scopes the stain-type lookup and assignment picker, and is spread into the run request. */
  casesParams: CasesParams;
  /** Invoked right after the inference request is dispatched (not when it finishes). */
  onClose: () => void;
}

// On-the-fly normalization starts enabled, seeded with the shared multiplex channel defaults.
const defaultOtfNormalizationConfig = {
  active: true,
  normParamsConfig: defaultChannelNormalizationConfig,
};

// By default normalization is computed on the fly rather than loaded from the DB.
// The "Use pre-normalized values" checkbox sets `loadParamsFromDb` and the inverse of
// `otfNormalizationConfig.active`.
const defaultNormalizationConfig = {
  active: true,
  loadParamsFromDb: false,
  otfNormalizationConfig: defaultOtfNormalizationConfig,
};

/** Initial form values; the wizard also falls back to these when clearing or partially restoring the form. */
const defaultValues = {
  useDynamicCellDetection: false,
  dynamicCellDetectionConfig: nonDynamicOption.value,
  skipRunExistingArtifacts: true,
  dedupValue: 3,
  roiMask: false,
  clearMlMachineType: 'spot',
  inferenceVmsLimit: -1,
  branch: 'dev',
  assignmentIds: new Array<string>(),
  classNames: new Array<string>(),
  channelsToExtract: ['dapi'],
  normalizationConfig: defaultNormalizationConfig,
};

/**
 * Multi-step wizard that collects inference-job parameters and submits them via `runInference`.
 *
 * Step indices (matching `validationSchema`): 0 old-job import, 1 presets, 2 name/description,
 * 3 model selection, 4 cell-model config, 5 TSM model, 6 ROI mask, 7 advanced.
 * The `setIs*StepFailed` helpers validate only one step's fields against
 * `validationSchema[stepIndex]` and record the outcome in `isStepFailed` for the stepper.
 *
 * @param onClose - called right after the run request is dispatched (the wizard does not wait for it).
 * @param casesParams - case filter; scopes stain types/assignments and is spread into the request.
 * @param jobId - optional existing job id; initial `selectedJobId` and forwarded to `JobWithRebuild`.
 */
export const RunInferenceSteps: React.FunctionComponent<React.PropsWithChildren<RunInferenceStepsProps>> = ({
  onClose,
  casesParams,
  jobId,
}) => {
  const { queryParams } = useEncodedFilters();
  const { enqueueSnackbar, closeSnackbar } = useSnackbar();
  const [activeStep, setActiveStep] = useState(0);
  const [currentPreset, setCurrentPreset] = useState('');
  const [isStepFailed, setIsStepFailed] = useState<Record<number, boolean>>({});
  const [selectedJobId, setSelectedJobId] = useState<string | undefined>(jobId);
  // Guards the one-time auto-fill of `modelType` from the selected model's metadata.
  const [isSetModelType, setIsSetModelType] = useState(false);

  const queryClient = useQueryClient();

  const { data: presets, isLoading: isLoadingPresets } = useQuery({
    queryKey: ['jobPresets', { steps: ['run_inference'] }],
    queryFn: ({ signal }) => getJobPresets(['run_inference'], signal),
  });

  // The resolver validates only the schema of the step currently displayed.
  const currentValidationSchema = validationSchema[activeStep];

  const methods = useForm<IFormValues>({
    mode: 'onChange',
    resolver: yupResolver(currentValidationSchema),
    defaultValues,
  });

  const {
    reset,
    getValues,
    register,
    handleSubmit,
    control,
    watch,
    setValue,
    formState: { errors },
  } = methods;

  const getEncodedParams = () => {
    return encodeQueryParamsUsingSchema(casesParams, casesSchema, casesOptions);
  };

  const modelId = getModelId(watch('model'));

  const { data: modelData } = useQuery({
    queryKey: ['model', modelId],
    queryFn: ({ signal }) => getModel(modelId, signal),
    retry: false,
    enabled: !isEmpty(watch('model')) && Boolean(modelId),
  });

  const { data: stains, isLoading: isLoadingStains } = useQuery({
    queryKey: ['slidesStainsTypes', getEncodedParams()],
    queryFn: ({ signal }) => getStainTypeFilteredIds(getEncodedParams(), signal),
  });
  const { stainTypeIdToDisplayName, isLoadingStainTypeOptions } = useStainTypeIdToDisplayName();

  // Update stain type dropdown if there is only one stain type.
  // This is to avoid the case where the user has to select the only stain type.
  // UseEffect is used because defaultValues from useForm is not updating the dropdown.
  useEffect(() => {
    if (!isLoadingStains && !isEmpty(stains) && stains.length === 1) {
      reset({
        ...getValues(),
        slideStainType: first(stains),
      });
      if (!isEmpty(jobId)) setIsSelectModelStepFailed();
    }
  }, [stains]);

  const stainsOptions = map(stains, (currStain) => ({
    value: currStain,
    text: stainTypeIdToDisplayName(currStain),
  }));

  const runInferenceMutation = useMutation(runInference, {
    onError: () => {
      enqueueSnackbar('Error occurred, Inference failed', {
        variant: 'error',
      });
    },
    onSuccess: () => {
      enqueueSnackbar('Inference Start', { variant: 'success' });
    },
    onSettled() {
      // Dismiss the persistent "waiting" snackbar enqueued in onSubmit.
      closeSnackbar(SNACK_BAR_KEY_RUN_INFERENCE);
    },
  });

  const createJobPresetMutation = useMutation(createJobPreset, {
    onError: () => {
      enqueueSnackbar('Error occurred, save job preset failed', {
        variant: 'error',
      });
    },
    onSuccess: () => {
      enqueueSnackbar('job preset saved', { variant: 'success' });
      queryClient.invalidateQueries({
        queryKey: ['jobPresets', { steps: ['run_inference'] }],
      });
    },
    onSettled() {
      // Dismiss the persistent "saving" snackbar enqueued in savePreset.
      closeSnackbar(SNACK_BAR_KEY_SAVE_PRESET);
    },
  });

  const onSubmit: SubmitHandler<IFormValues> = async (data) => {
    runInferenceMutation.mutate({
      ...casesParams,
      configParams: {
        ...data,
        studyId: queryParams.filters?.studyId,
        stainType: getStainingMethod(data.slideStainType),
      },
    });

    // Persistent notification; closed by the mutation's onSettled via its key.
    enqueueSnackbar({
      variant: 'success',
      message: (
        <Grid container>
          <Grid item>
            <Typography>Waiting for Inference to start</Typography>
          </Grid>
          <Grid item>
            <CircularProgress sx={{ marginLeft: 10 }} color="inherit" size={20} />
          </Grid>
        </Grid>
      ),
      key: SNACK_BAR_KEY_RUN_INFERENCE,
      autoHideDuration: null,
    });

    onClose();
  };

  const onSelectModel = (model: Model) => {
    setValue('model', model.url, { shouldValidate: true });
    setIsSelectModelStepFailed();
  };

  const onSelectTsmModel = (model: Model) => {
    setValue('tsmModel', model.url);
    setIsSelectTsmModelStepFailed();
  };

  const onNormalizationConfigChange = (newNormalizationConfig: NormalizationConfig) => {
    setValue('normalizationConfig.otfNormalizationConfig.normParamsConfig', newNormalizationConfig);
    setIsAdvancedStepFailed();
  };

  const modelTypes = keyBy(modelTypeRunInferenceOptions, 'value');

  const selectPreset = (preset: Partial<BasePreset>) => {
    setCurrentPreset(preset.id);
    const presetData = find(presets, { id: preset.id });

    // Merge over the defaults so presets saved before a field existed (e.g. normalizationConfig,
    // channelsToExtract) don't wipe that field to `undefined`; preset values still win.
    reset({
      ...defaultValues,
      ...presetData?.presetJson,
    });

    setIsNameAndDescriptionStepFailed();
    setIsSelectModelStepFailed();
    setIsCellModelConfigStepFailed();
    setIsSelectTsmModelStepFailed();
    setIsRoiMaskStepFailed();
    setIsAdvancedStepFailed();
  };

  const savePreset = (name: string) => {
    createJobPresetMutation.mutate({
      name,
      stains: watch('slideStainType') ? [watch('slideStainType')] : null,
      steps: ['run_inference'],
      sourceStudyId: queryParams.filters?.studyId,
      presetJson: {
        ...getValues(),
      },
    });
    // Persistent notification; closed by the mutation's onSettled via its key.
    enqueueSnackbar({
      variant: 'success',
      message: (
        <Grid container>
          <Grid item>
            <Typography>Saving preset...</Typography>
          </Grid>
          <Grid item>
            <CircularProgress sx={{ marginLeft: 10 }} color="inherit" size={20} />
          </Grid>
        </Grid>
      ),
      key: SNACK_BAR_KEY_SAVE_PRESET,
      autoHideDuration: null,
    });
  };

  /**
   * Validates `objectToValidate` against the schema of step `stepIndex` (asynchronously) and
   * records pass/fail for that step in `isStepFailed`.
   */
  const checkValidationAndSetIsStepFailed = (stepIndex: number, objectToValidate: Record<string, unknown>) => {
    validationSchema[stepIndex]
      .validate(objectToValidate)
      .then(() => {
        setIsStepFailed((prev) => ({
          ...prev,
          [stepIndex]: false,
        }));
      })
      .catch(() => {
        setIsStepFailed((prev) => ({
          ...prev,
          [stepIndex]: true,
        }));
      });
  };

  const setIsNameAndDescriptionStepFailed = () => {
    checkValidationAndSetIsStepFailed(2, {
      jobName: watch('jobName'),
      jobDescription: watch('jobDescription'),
      hmName: watch('hmName'),
    });
  };

  const setIsSelectModelStepFailed = () => {
    checkValidationAndSetIsStepFailed(3, {
      slideStainType: watch('slideStainType'),
      model: watch('model'),
      modelType: watch('modelType'),
    });
  };

  const setIsCellModelConfigStepFailed = () => {
    checkValidationAndSetIsStepFailed(4, {
      // modelType is needed by the schema's `dedupValue.when('modelType', ...)` condition.
      modelType: watch('modelType'),
      dedupValue: watch('dedupValue'),
      useDynamicCellDetection: watch('useDynamicCellDetection'),
      dynamicCellDetectionConfig: watch('dynamicCellDetectionConfig'),
    });
  };

  const setIsSelectTsmModelStepFailed = () => {
    checkValidationAndSetIsStepFailed(5, {
      modelType: watch('modelType'),
      tsmModel: watch('tsmModel'),
    });
  };

  const setIsRoiMaskStepFailed = () => {
    checkValidationAndSetIsStepFailed(6, {
      roiMask: watch('roiMask'),
      assignmentIds: watch('assignmentIds'),
      classNames: watch('classNames'),
    });
  };

  const setIsAdvancedStepFailed = () => {
    checkValidationAndSetIsStepFailed(7, {
      skipRunExistingArtifacts: watch('skipRunExistingArtifacts'),
      clearMlMachineType: watch('clearMlMachineType'),
      inferenceVmsLimit: watch('inferenceVmsLimit'),
      branch: watch('branch'),
      normalizationConfig: watch('normalizationConfig'),
    });
  };

  // Auto-fill `modelType` once from the loaded model's metadata when the user hasn't chosen one.
  // Done in an effect (not during render) because reset() notifies form-state subscribers,
  // including child Controllers, which must not be updated while this component renders.
  const currentModelType = watch('modelType');
  useEffect(() => {
    if (!isSetModelType && isEmpty(currentModelType) && !isEmpty(modelData?.meta?.modelType)) {
      reset({
        ...getValues(),
        modelType: modelTypesByApiModelValue[modelData.meta.modelType]?.value,
      });

      setIsSelectModelStepFailed();
      setIsSelectTsmModelStepFailed();
      setIsSetModelType(true);
    }
  }, [modelData, currentModelType, isSetModelType]);

  /**
   * Fills (or clears, when `newValue` is empty) the form from a previously run inference job.
   * The stain type the user already picked is always kept.
   */
  const onSelectedJobParamChange = (newValue: InferenceJob) => {
    if (isEmpty(newValue)) {
      setSelectedJobId(undefined);
      reset({
        jobName: '',
        jobDescription: '',
        hmName: '',
        slideStainType: getValues('slideStainType'),
        model: '',
        tsmModel: '',
        modelType: '',
        dedupValue: defaultValues.dedupValue,
        roiMask: defaultValues.roiMask,
        assignmentIds: [],
        classNames: [],
        clearMlMachineType: defaultValues.clearMlMachineType,
        inferenceVmsLimit: defaultValues.inferenceVmsLimit,
        branch: defaultValues.branch,
        skipRunExistingArtifacts: defaultValues.skipRunExistingArtifacts,
        useDynamicCellDetection: defaultValues.useDynamicCellDetection,
        dynamicCellDetectionConfig: defaultValues.dynamicCellDetectionConfig,
        // reset() replaces all values; without these the Advanced step would lose its defaults.
        channelsToExtract: defaultValues.channelsToExtract,
        normalizationConfig: defaultValues.normalizationConfig,
      });
    } else {
      setSelectedJobId(newValue.id);
      reset({
        jobName: newValue?.name,
        jobDescription: newValue?.description,
        hmName: first(newValue?.params.heatmapNames),
        slideStainType: getValues('slideStainType'),
        model: first(newValue?.params.modelPaths),
        tsmModel: newValue?.params.tissueSegmentationModelOverride,
        modelType: newValue?.params.runType,
        roiMask: Boolean(newValue?.params.tissueMaskFromAnnotations) || defaultValues.roiMask,
        assignmentIds: newValue?.params.assignmentIds,
        classNames: newValue?.params.classNames,
        clearMlMachineType: newValue?.params.clearmlMachineType || defaultValues.clearMlMachineType,
        // `??` (not `||`) so a stored 0 / false from the old job is kept instead of replaced by the default.
        inferenceVmsLimit: newValue?.params.inferenceVmsLimit ?? defaultValues.inferenceVmsLimit,
        branch: newValue?.params.branchName || defaultValues.branch,
        skipRunExistingArtifacts: newValue?.params.skipRunExistingArtifacts ?? defaultValues.skipRunExistingArtifacts,
        // Number(undefined) is NaN, which `||` maps to the default.
        dedupValue: Number(newValue?.params.dedupValue) || defaultValues.dedupValue,
        useDynamicCellDetection: newValue?.params.useDynamicCellDetection ?? defaultValues.useDynamicCellDetection,
        dynamicCellDetectionConfig:
          newValue?.params.dynamicCellDetectionConfig || defaultValues.dynamicCellDetectionConfig,
        // NOTE(review): the old job's channels/normalization are not mapped here (param names
        // not visible); restore defaults rather than leaving these fields undefined.
        channelsToExtract: defaultValues.channelsToExtract,
        normalizationConfig: defaultValues.normalizationConfig,
      });

      setIsNameAndDescriptionStepFailed();
      setIsSelectModelStepFailed();
      setIsCellModelConfigStepFailed();
      setIsSelectTsmModelStepFailed();
      setIsRoiMaskStepFailed();
      setIsAdvancedStepFailed();
    }
  };

  const steps = [
    {
      label: 'Upload Params From Old Job',
      optional: true,
      content: (
        <OldJobsStep
          jobType={JobType.Inference}
          onSelectedJob={onSelectedJobParamChange}
          selectedJobId={selectedJobId}
        />
      ),
    },
    {
      label: 'Presets',
      content: (
        <Grid>
          <PresetSection
            selectedPreset={find(presets, { id: currentPreset })}
            presets={map(
              filter(presets, (preset) => !preset.deletedBy),
              (preset) => ({
                ...preset,
                displayName: (
                  <Grid container direction="row" spacing={1} alignItems="center">
                    <Grid item>
                      <Typography variant="body1">{preset.name}</Typography>
                    </Grid>
                    <Grid item>
                      <Typography variant="body2">{moment(preset.createdAt).format('YYYY-MM-DD HH:mm')}</Typography>
                    </Grid>
                    <Grid item>
                      <Typography variant="caption">
                        (
                        {join(
                          map(preset.stains, (stain) => stainTypeIdToDisplayName(stain)),
                          ', '
                        )}
                        )
                      </Typography>
                    </Grid>
                  </Grid>
                ),
              })
            )}
            onSelectPreset={selectPreset}
            label={
              <>
                Inference presets{' '}
                {isLoadingPresets && <CircularProgress size={18.5} title="Loading inference presets..." />}
              </>
            }
            isLoading={isLoadingPresets}
          />
        </Grid>
      ),
    },
    {
      label: 'Name and Description',
      subLabel: watch('jobName'),
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="jobName"
              render={({ field: { onChange } }) => (
                <TextField
                  label="Job Name"
                  {...register('jobName')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['jobName'])}
                  helperText={humanize(errors['jobName']?.message)}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="jobDescription"
              render={({ field: { onChange } }) => (
                <TextField
                  label="Job Description"
                  {...register('jobDescription')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['jobDescription'])}
                  helperText={humanize(errors['jobDescription']?.message)}
                  multiline
                  minRows={4}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="hmName"
              render={({ field: { onChange } }) => (
                <TextField
                  label="HM Name"
                  {...register('hmName')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['hmName'])}
                  helperText={humanize(errors['hmName']?.message)}
                  required
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsNameAndDescriptionStepFailed();
      },
    },
    {
      label: 'Select Model',
      subLabel: watch('model'),
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item container spacing={2}>
            <Grid item xs={6}>
              <Controller
                control={control}
                name="slideStainType"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    key={stainsOptions?.length}
                    disabled={stainsOptions?.length < 2}
                    label="Stain Type Selection"
                    options={stainsOptions}
                    value={(!isLoadingStains && stainsOptions?.length === 1 && first(stainsOptions).value) || value}
                    onOptionSelected={(optionValue) => {
                      onChange(optionValue);
                      setIsSelectModelStepFailed();
                    }}
                    error={Boolean(errors['slideStainType'])}
                    helperText={
                      errors['slideStainType']?.message
                        ? humanize(errors['slideStainType']?.message)
                        : stainsOptions?.length > 1
                        ? 'There is more than one stain type, please select one'
                        : 'required before choosing model'
                    }
                    required
                    loading={isLoadingStains || isLoadingStainTypeOptions}
                  />
                )}
              />
            </Grid>
            <Grid item xs={6}>
              <Controller
                control={control}
                name="modelType"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    label="Model Type"
                    options={modelTypeRunInferenceOptions}
                    value={value ?? ''}
                    onOptionSelected={(optionValue) => {
                      onChange(optionValue);
                      // A model of the previous type is no longer valid.
                      setValue('model', '');
                      setIsSelectModelStepFailed();
                      setIsSelectTsmModelStepFailed();
                    }}
                    error={Boolean(errors['modelType'])}
                    helperText={
                      errors['modelType']?.message
                        ? humanize(errors['modelType']?.message)
                        : 'required before choosing model'
                    }
                    required
                  />
                )}
              />
            </Grid>
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="model"
              render={({ field: { onChange } }) => (
                <TextField
                  value={watch('model') ?? ''}
                  label="Model Url"
                  {...register('model')}
                  onChange={(event) => {
                    onChange(event);
                    setIsSelectModelStepFailed();
                  }}
                  placeholder="Choose model from the table or write here the artifact url"
                  error={Boolean(errors['model'])}
                  helperText={humanize(errors['model']?.message)}
                  required
                />
              )}
            />
          </Grid>
          <Grid item>
            <ModelsTable
              modelType={modelTypes[watch('modelType')]?.apiModelValue}
              stainType={watch('slideStainType')}
              onSelect={onSelectModel}
              modelUrlSelected={watch('model')}
              enabled={!isEmpty(watch('slideStainType')) && !isEmpty(watch('modelType'))}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsSelectModelStepFailed();
        setIsCellModelConfigStepFailed();
      },
    },
    {
      label: 'Cell model configuration',
      optional: Boolean(watch('modelType')) && watch('modelType') !== modelTypeCell.value,
      skip: watch('modelType') !== modelTypeCell.value,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="dedupValue"
              render={({ field: { onChange } }) => (
                <TextField
                  disabled={watch('modelType') !== modelTypeCell.value}
                  label="Dedup Value"
                  type="number"
                  {...register('dedupValue')}
                  onChange={(event) => {
                    onChange(event);
                    setIsCellModelConfigStepFailed();
                  }}
                  placeholder="Type Here"
                  error={Boolean(errors['dedupValue'])}
                  helperText={
                    errors['dedupValue']?.message
                      ? humanize(errors['dedupValue']?.message)
                      : watch('modelType') !== modelTypeCell.value && 'Only for cell model'
                  }
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="useDynamicCellDetection"
              render={({ field: { onChange } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      {...register('useDynamicCellDetection')}
                      onChange={(event) => {
                        onChange(event);
                        setIsCellModelConfigStepFailed();
                      }}
                    />
                  }
                  label={
                    <Typography>
                      Use Dynamic Cell Detection{' '}
                      <Typography variant="caption">
                        (If you run with dev, this param will not have an effect, and this selection will be taken from
                        dynamic presets param below)
                      </Typography>
                    </Typography>
                  }
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="dynamicCellDetectionConfig"
              render={({ field: { onChange, value } }) => (
                <LabelledDropdown
                  label="Dynamic Cell Detection Config"
                  options={dynamicCellDetectionConfigOptions}
                  value={value ?? ''}
                  onOptionSelected={(optionValue) => {
                    onChange(optionValue);
                    setIsCellModelConfigStepFailed();
                  }}
                  error={Boolean(errors['dynamicCellDetectionConfig'])}
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsCellModelConfigStepFailed();
      },
    },
    {
      label: 'Select TSM Model',
      subLabel: watch('tsmModel'),
      optional: Boolean(watch('modelType')) && watch('modelType') === modelTypeDefect.value,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="tsmModel"
              render={({ field: { onChange } }) => (
                <TextField
                  value={watch('tsmModel') ?? ''}
                  label="TSM Model Url"
                  {...register('tsmModel')}
                  onChange={(event) => {
                    onChange(event);
                    setIsSelectTsmModelStepFailed();
                  }}
                  placeholder="Choose tsm model from the table or write here the artifact url"
                  error={Boolean(errors['tsmModel'])}
                  helperText={humanize(errors['tsmModel']?.message)}
                  required={watch('modelType') !== modelTypeDefect.value}
                />
              )}
            />
          </Grid>
          <Grid item>
            <ModelsTable
              modelType={modelTypeTsm.apiModelValue}
              onSelect={onSelectTsmModel}
              modelUrlSelected={watch('tsmModel')}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsSelectTsmModelStepFailed();
      },
    },
    {
      label: 'ROI Mask',
      optional: true,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item xs={6}>
            <Controller
              control={control}
              name="roiMask"
              render={({ field: { onChange } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      {...register('roiMask')}
                      onChange={(event) => {
                        onChange(event);
                        setIsRoiMaskStepFailed();
                      }}
                    />
                  }
                  label="ROI Mask"
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="assignmentIds"
              render={({ field: { onChange, value } }) => (
                <AnnotationAssignmentAutocomplete
                  casesParams={casesParams}
                  slideStainType={watch('slideStainType')}
                  multiple
                  disabled={isEmpty(watch('slideStainType')) || !watch('roiMask')}
                  limitTags={1}
                  selectedValue={map(value, Number)}
                  onChange={(event, newValue) => {
                    onChange(map(newValue, 'annotationAssignmentId'));
                    setIsRoiMaskStepFailed();
                  }}
                  textFieldProps={{
                    label: 'Assignments',
                    error: Boolean(errors['assignmentIds']),
                    helperText: !watch('roiMask')
                      ? 'Only when ROI mask is checked'
                      : isEmpty(watch('slideStainType'))
                      ? 'Need to choose stain type'
                      : humanize(errors['assignmentIds']?.message),
                  }}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="classNames"
              render={({ field: { onChange, value } }) => (
                <LabelledDropdown
                  multiple={true}
                  disabled={!watch('roiMask')}
                  label="Class Names"
                  options={classNamesOptions}
                  value={value ?? []}
                  onOptionSelected={(optionValue) => {
                    onChange(optionValue);
                    setIsRoiMaskStepFailed();
                  }}
                  error={Boolean(errors['classNames'])}
                  helperText={
                    !watch('roiMask') ? 'Only when ROI mask is checked' : humanize(errors['classNames']?.message)
                  }
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsRoiMaskStepFailed();
      },
    },
    {
      label: 'Advanced',
      optional: true,
      content: (
        <Grid item container direction="column" spacing={2}>
          <Grid item container direction="column" spacing={2}>
            <Grid item>
              {/* select channelsToExtract (choose between dapi and "hoechst")  */}
              <Controller
                control={control}
                name="channelsToExtract"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    label="Channel To Extract"
                    options={channelsToExtractOptions}
                    // in the defects case, there is only one channel to extract
                    value={first(value) ?? ''}
                    onOptionSelected={(optionValue) => {
                      onChange([optionValue]);
                    }}
                    error={Boolean(errors['channelsToExtract'])}
                    helperText={humanize(errors['channelsToExtract']?.message)}
                  />
                )}
              />
            </Grid>
            <Grid item>
              <Controller
                control={control}
                name="normalizationConfig.loadParamsFromDb" // Boolean field name
                render={({ field: { onChange, value } }) => (
                  <FormControlLabel
                    control={
                      <Checkbox
                        checked={!!value} // Ensure it's a boolean
                        onChange={(event) => {
                          const isChecked = event.target.checked;
                          onChange(isChecked);
                          // Loading params from the DB and on-the-fly normalization are mutually exclusive.
                          setValue(
                            'normalizationConfig.otfNormalizationConfig.active',
                            !isChecked,
                            { shouldValidate: true, shouldDirty: true } // Ensures validation and dirty tracking
                          );
                        }}
                      />
                    }
                    label="Use pre-normalized values (min_c, max_c)"
                  />
                )}
              />
            </Grid>
            {!watch('normalizationConfig.loadParamsFromDb') && (
              <Grid item>
                <NormalizationParamsForm
                  normalizationConfig={watch('normalizationConfig.otfNormalizationConfig.normParamsConfig') || {}}
                  setNormalizationConfig={onNormalizationConfigChange}
                />
              </Grid>
            )}
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="skipRunExistingArtifacts"
              render={({ field: { onChange } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      defaultChecked
                      {...register('skipRunExistingArtifacts')}
                      onChange={(event) => {
                        onChange(event);
                      }}
                    />
                  }
                  label={
                    <Typography>
                      Skip Existing Artifacts{' '}
                      <Typography variant="caption">
                        (Skip run if artifact exists - if False, an artifact will be created even if already exists )
                      </Typography>
                    </Typography>
                  }
                />
              )}
            />
          </Grid>
          <Grid container item spacing={2}>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="clearMlMachineType"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    label="clearMl Machine Type"
                    options={clearMlMachineTypeOptions}
                    value={value}
                    onOptionSelected={(optionValue) => {
                      onChange(optionValue);
                    }}
                    error={Boolean(errors['clearMlMachineType'])}
                    helperText={humanize(errors['clearMlMachineType']?.message)}
                    required
                  />
                )}
              />
            </Grid>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="inferenceVmsLimit"
                render={({ field: { onChange } }) => (
                  <TextField
                    label="Inference Vms Limit"
                    type="number"
                    InputProps={{
                      inputProps: {
                        max: 100,
                        min: -1,
                      },
                    }}
                    {...register('inferenceVmsLimit')}
                    onChange={onChange}
                    placeholder="Type Here"
                    error={Boolean(errors['inferenceVmsLimit'])}
                    helperText={humanize(errors['inferenceVmsLimit']?.message)}
                  />
                )}
              />
            </Grid>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="branch"
                render={({ field: { onChange } }) => (
                  <TextField
                    label="Branch"
                    {...register('branch')}
                    onChange={onChange}
                    placeholder="Type Here"
                    error={Boolean(errors['branch'])}
                    helperText={humanize(errors['branch']?.message)}
                    required
                  />
                )}
              />
            </Grid>
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsAdvancedStepFailed();
      },
    },
  ];

  return (
    <JobWithRebuild jobId={jobId} onSelectedJobParamChange={onSelectedJobParamChange}>
      <PlatformStepper
        handleSubmit={handleSubmit(onSubmit)}
        steps={steps}
        setActiveStepForValidation={setActiveStep}
        isStepFailed={isStepFailed}
        handleSaveAsPreset={savePreset}
      />
    </JobWithRebuild>
  );
};

/**
 * Settings for computing normalization values on the fly during inference,
 * instead of reusing normalization values already stored in the DB.
 */
interface OnTheFlyNormalizationConfig {
  /** Enables on-the-fly normalization (required by the form's validation schema). */
  active: boolean;

  /**
   * Normalization parameters; validated with `normalizationSchema` in this file.
   * NOTE(review): meaning of `null` vs. omitted is not visible here — confirm with the job runner.
   */
  normParamsConfig?: NormalizationConfig | null;

  /** Channel name for normalization — presumably the multiplex channel to normalize; confirm. */
  channelName?: string | null;
  /** Histogram root — presumably a storage path/prefix for histograms; confirm against the backend. */
  histogramRoot?: string | null;
}

/**
 * Normalization options carried by the inference form values.
 */
interface InferenceNormalizationConfig {
  /**
   * Parameters to find normalization values on the fly.
   * When null, the normalization stored in the DB is used instead.
   */
  otfNormalizationConfig?: OnTheFlyNormalizationConfig | null;
  /** When true, the per_slide_params are loaded from the DB. */
  loadParamsFromDb: boolean;
  /** Enabled flag for this normalization config (required by the form's validation schema). */
  active: boolean;
}

/**
 * Values collected by the inference-job form.
 *
 * Fields are grouped in the order of the `validationSchema` steps that
 * validate them; `channelsToExtract` is not validated by any step.
 */
export interface IFormValues {
  // Job details
  jobName: string;
  jobDescription: string;
  hmName: string;

  // Stain type and model selection
  slideStainType: string;
  model: string;
  modelType: string;

  // Dedup value (required when modelType is the cell model type) and dynamic cell detection
  dedupValue: number;
  useDynamicCellDetection: boolean;
  dynamicCellDetectionConfig: string;

  // TSM model (optional/nullable in the schema only when modelType is the defect model type)
  tsmModel: string;

  // ROI masking; assignmentIds/classNames are required when roiMask is checked
  roiMask: boolean;
  assignmentIds: string[];
  classNames: string[];

  // Run settings
  skipRunExistingArtifacts: boolean;
  clearMlMachineType: string;
  inferenceVmsLimit: number;
  branch: string;
  normalizationConfig?: InferenceNormalizationConfig;

  // Not covered by validationSchema
  channelsToExtract?: string[];
}

/**
 * Per-step yup schemas for the inference-job stepper.
 *
 * NOTE(review): presumably indexed by the active stepper step (the component
 * passes `setActiveStep` as `setActiveStepForValidation`) — confirm against the resolver setup.
 * The first two steps have no validated fields.
 */
const validationSchema = [
  yup.object({}),
  yup.object({}),
  // Job details
  yup.object({
    jobName: yup.string(),
    jobDescription: yup.string(),
    hmName: yup.string().required(),
  }),
  // Stain type and model selection
  yup.object({
    slideStainType: yup.string().required('Stain Type is a required field'),
    model: yup.string().required(),
    modelType: yup.string().required(),
  }),
  // Dedup value is only mandatory for the cell model type
  yup.object({
    dedupValue: yup.number().when('modelType', {
      is: modelTypeCell.value,
      then: yup.number().required('Dedup Value is a required field'),
    }),
    useDynamicCellDetection: yup.boolean(),
    dynamicCellDetectionConfig: yup.string(),
  }),
  // TSM model is mandatory for every model type except the defect model type
  yup.object({
    tsmModel: yup.string().when('modelType', {
      is: modelTypeDefect.value,
      then: yup.string().notRequired().nullable(),
      otherwise: yup.string().required(),
    }),
  }),
  // ROI masking
  yup.object({
    roiMask: yup.boolean().required(),
    // `array().required()` accepts an empty array (yup >= 0.32), so `min(1)` is what
    // actually enforces a non-empty selection when roiMask is checked.
    assignmentIds: yup.array().when('roiMask', {
      is: true,
      then: yup.array().required('Required when roiMask is checked').min(1, 'Required when roiMask is checked'),
    }),
    classNames: yup.array().when('roiMask', {
      is: true,
      then: yup.array().required('Required when roiMask is checked').min(1, 'Required when roiMask is checked'),
    }),
  }),
  // Run settings
  yup.object({
    skipRunExistingArtifacts: yup.boolean().required(),
    clearMlMachineType: yup.string().required(),
    inferenceVmsLimit: yup.number().required(),
    branch: yup.string().required(),
    normalizationConfig: yup.object({
      active: yup.boolean().required(),
      loadParamsFromDb: yup.boolean().required(),
      // null is a legitimate value ("use normalization from DB"), matching the
      // `| null` in InferenceNormalizationConfig.otfNormalizationConfig.
      otfNormalizationConfig: yup
        .object({
          active: yup.boolean().required(),
          normParamsConfig: normalizationSchema,
        })
        .nullable(),
    }),
  }),
];

/**
 * Maps a slide stain type to a staining method.
 * 'he' and 'mplex' pass through unchanged; any other stain type maps to 'ihc'.
 * The comparison is exact (case-sensitive).
 */
const getStainingMethod = (slideStainType: string): string => {
  switch (slideStainType) {
    case 'he':
    case 'mplex':
      return slideStainType;
    default:
      return 'ihc';
  }
};
