import { yupResolver } from '@hookform/resolvers/yup';
import { Checkbox, CircularProgress, FormControlLabel, Grid, TextField, Typography } from '@mui/material';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { first, indexOf, isEmpty, map } from 'lodash';
import { useSnackbar } from 'notistack';
import React, { useState } from 'react';
import { Controller, SubmitHandler, useForm } from 'react-hook-form';
import * as yup from 'yup';

import { createJobPreset } from 'api/jobPreset';
import { runInference } from 'api/platform';
import { AnnotationAssignmentAutocomplete } from 'components/atoms/AnnotationAssignmentAutocomplete';
import LabelledDropdown from 'components/atoms/Dropdown/LabelledDropdown';
import ModelsTable from 'components/ModelsTable';
import {
  channelsToExtractOptions,
  classNamesOptions,
  clearMlMachineTypeOptions,
  defaultDynamicOption,
  dynamicCellDetectionConfigOptions,
  modelTypeCell,
  modelTypeDefect,
  modelTypeTsm
} from 'components/Pages/Jobs/inferenceFieldsOptions';
import { InferenceJob } from 'interfaces/job';
import { JobPreset } from 'interfaces/jobPreset';
import { NormalizationConfig } from 'interfaces/jobs/multiplex/normalizationParams';
import { Model } from 'interfaces/model';
import { getStainingMethodByStainType } from 'interfaces/stainType';
import { humanize } from 'utils/helpers';
import { CasesParams } from 'utils/useCasesParams';
import { useEncodedFilters } from 'utils/useEncodedFilters';
import { JobWithRebuild } from '../JobWithRebuild';
import { defaultChannelNormalizationConfig, normalizationSchema } from '../Multiplex/RunMultiplexNormalization';
import NormalizationParamsForm from '../Multiplex/RunNormalization/NormalizationParamsForm';
import { PlatformStepper } from '../PlatformStepper';
import { SelectModelStep } from '../SelectModelStep';
import { LoadParamsStep } from './LoadParamsStep';

// notistack keys of the persistent "in progress" snackbars; the mutations close them by key once settled.
const SNACK_BAR_KEY_RUN_INFERENCE = 'RUN_INFERENCE',
  SNACK_BAR_KEY_SAVE_PRESET = 'SAVE_PRESET';

/** Props of the `RunInferenceSteps` wizard. */
export interface RunInferenceStepsProps {
  /** Case selection; spread into the `runInference` request and handed to the child steps. */
  casesParams: CasesParams;
  /** Invoked after inference was started successfully. */
  onClose: () => void;
  /** Job to rebuild from; when set it takes precedence over a job picked in the Load Params step. */
  jobId?: string;
  /** When true, a failed step validation also logs the raw error and the validated values. */
  debug?: boolean;
}

// Default on-the-fly normalization settings: enabled, starting from the shared
// per-channel defaults exported by RunMultiplexNormalization.
// Referenced (not copied) by defaultNormalizationConfig below; treat as read-only.
const defaultOtfNormalizationConfig = {
  active: true,
  normParamsConfig: defaultChannelNormalizationConfig,
};

// Default top-level normalization settings of the inference form.
// With `loadParamsFromDb: false` the Advanced step shows the on-the-fly normalization
// form; toggling that checkbox sets `otfNormalizationConfig.active` to its inverse.
// Referenced (not copied) by defaultValues and passed to reset(); treat as read-only.
const defaultNormalizationConfig = {
  active: true,
  loadParamsFromDb: false,
  otfNormalizationConfig: defaultOtfNormalizationConfig,
};

// Initial values of the inference form. Individual entries are also used as fallbacks
// when parameters are loaded from a preset or an existing job, and the whole object is
// used when a preset carries no presetJson.
const defaultValues = {
  useDynamicCellDetection: true,
  dynamicCellDetectionConfig: defaultDynamicOption.value,
  skipRunExistingArtifacts: true,
  // Only editable (and only validated as required) for cell models.
  dedupValue: 3,
  roiMask: false,
  clearMlMachineType: 'spot',
  // NOTE(review): -1 is the input's minimum; presumably means "no limit" — confirm with the platform.
  inferenceVmsLimit: -1,
  branch: 'dev',
  assignmentIds: [] as string[],
  classNames: [] as string[],
  // The Advanced step edits this as a single selected channel.
  channelsToExtract: ['dapi'],
  normalizationConfig: defaultNormalizationConfig,
};

/**
 * Identifiers of the wizard steps. Each value doubles as the step's label and as its key
 * into `validationSchema` / the component's steps map.
 */
const RunInferenceStep = {
  LoadParams: 'Load Params',
  NameAndDescription: 'Name and Description',
  SelectModel: 'Select Model',
  CellModelConfig: 'Cell Model Config',
  SelectTsmModel: 'Select TSM Model',
  RoiMask: 'ROI Mask',
  Advanced: 'Advanced',
} as const;
type RunInferenceStep = (typeof RunInferenceStep)[keyof typeof RunInferenceStep];

/**
 * Order in which the wizard presents its steps; the stepper's `activeStep` index and the
 * per-step failure flags both index into this list. Cell Model Config deliberately comes
 * after ROI Mask here, unlike the declaration order of `RunInferenceStep`.
 */
const runInferenceStepsOrder: RunInferenceStep[] = [
  RunInferenceStep.LoadParams, // 0
  RunInferenceStep.NameAndDescription, // 1
  RunInferenceStep.SelectModel, // 2
  RunInferenceStep.SelectTsmModel, // 3
  RunInferenceStep.RoiMask, // 4
  RunInferenceStep.CellModelConfig, // 5
  RunInferenceStep.Advanced, // 6
];

/**
 * Multi-step wizard that collects the parameters of an inference run and starts it.
 *
 * All steps share one react-hook-form instance. The form resolver validates only the
 * schema of the step currently shown (`activeStep`), while the `setIs*StepFailed`
 * helpers validate a single step on demand so the stepper can flag failed steps.
 * Parameters can be loaded from a saved job preset or an existing job, and the current
 * values can be saved as a new preset.
 *
 * @param onClose - called once inference was started successfully.
 * @param casesParams - spread into the `runInference` request and passed to child steps.
 * @param jobId - when set, overrides the job selected in the Load Params step.
 * @param debug - when true, failed step validations also log the error and validated values.
 */
export const RunInferenceSteps: React.FunctionComponent<React.PropsWithChildren<RunInferenceStepsProps>> = ({
  onClose,
  casesParams,
  jobId,
  debug = true,
}) => {
  const { queryParams } = useEncodedFilters();
  const { enqueueSnackbar, closeSnackbar } = useSnackbar();
  const [activeStep, setActiveStep] = useState(0);
  const [isStepFailed, setIsStepFailed] = useState<Record<number, boolean>>({});

  const queryClient = useQueryClient();

  // The resolver only validates the fields owned by the step currently shown.
  const currentValidationSchema = validationSchema[runInferenceStepsOrder[activeStep]];

  const formMethods = useForm<IFormValues>({
    mode: 'onChange',
    resolver: yupResolver(currentValidationSchema),
    defaultValues,
  });

  const {
    reset,
    getValues,
    register,
    handleSubmit,
    control,
    watch,
    setValue,
    formState: { errors },
  } = formMethods;

  /** Snackbar body showing `text` next to a spinner while a request is in flight. */
  const renderProgressMessage = (text: string) => (
    <Grid container>
      <Grid item>
        <Typography>{text}</Typography>
      </Grid>
      <Grid item>
        <CircularProgress sx={{ marginLeft: 10 }} color="inherit" size={20} />
      </Grid>
    </Grid>
  );

  const runInferenceMutation = useMutation(runInference, {
    onError: () => {
      enqueueSnackbar('Error occurred, Inference failed', {
        variant: 'error',
      });
    },
    onSuccess: () => {
      enqueueSnackbar('Inference Start', { variant: 'success' });
      onClose();
    },
    onSettled() {
      closeSnackbar(SNACK_BAR_KEY_RUN_INFERENCE);
    },
  });

  const createJobPresetMutation = useMutation(createJobPreset, {
    onError: () => {
      enqueueSnackbar('Error occurred, save job preset failed', { variant: 'error' });
    },
    onSuccess: () => {
      enqueueSnackbar('job preset saved', { variant: 'success' });
      queryClient.invalidateQueries({
        queryKey: ['jobPresets', { steps: ['run_inference'] }],
      });
    },
    onSettled() {
      closeSnackbar(SNACK_BAR_KEY_SAVE_PRESET);
    },
  });

  const onSubmit: SubmitHandler<IFormValues> = async (data) => {
    enqueueSnackbar({
      variant: 'success',
      message: renderProgressMessage('Waiting for Inference to start'),
      key: SNACK_BAR_KEY_RUN_INFERENCE,
      autoHideDuration: null,
    });

    try {
      await runInferenceMutation.mutateAsync({
        ...casesParams,
        configParams: {
          ...data,
          studyId: queryParams.filters?.studyId,
          // TODO: discuss changing this name with platform team, since this is confusing
          // This is not a stain type, but a derived value from the slide
          stainType: getStainingMethodByStainType(data.slideStainType),
        },
      });
    } catch {
      // The failure is already reported by the mutation's `onError`; catching it here keeps
      // the rejection from escaping `handleSubmit` as an unhandled promise rejection.
    }
  };

  const onSelectTsmModel = (model: Model) => {
    setValue('tsmModel', model.url);
    setIsSelectTsmModelStepFailed();
  };

  const onNormalizationConfigChange = (newNormalizationConfig: NormalizationConfig) => {
    setValue('normalizationConfig.otfNormalizationConfig.normParamsConfig', newNormalizationConfig);
    setIsAdvancedStepFailed();
  };

  const [currentPreset, setCurrentPreset] = useState('');
  const [selectedJobIdState, setSelectedJobId] = useState<string | undefined>(undefined);
  // A job id passed in by the caller wins over the one picked in the Load Params step.
  const selectedJobId = jobId || selectedJobIdState;
  const [loadedFrom, setLoadedFrom] = useState<string>('');

  /** Loads a saved preset into the form and re-validates every step. */
  const selectPreset = (presetData: JobPreset) => {
    const presetJson = presetData?.presetJson;
    // `savePreset` stores the raw form values, so normalization lives under
    // `normalizationConfig`; `normalization` is kept as a fallback for presets using that key.
    const presetNormalization = presetJson?.normalizationConfig ?? presetJson?.normalization;

    setCurrentPreset(presetData?.id ?? '');
    setLoadedFrom(presetData?.name || presetData?.id || '');
    reset({
      ...(presetJson || defaultValues),
      // Make sure default values are set at every level of normalizationConfig
      normalizationConfig: {
        ...defaultNormalizationConfig,
        ...(presetNormalization || {}),
        otfNormalizationConfig: {
          ...defaultOtfNormalizationConfig,
          ...(presetNormalization?.otfNormalizationConfig || {}),
          normParamsConfig: {
            ...defaultChannelNormalizationConfig,
            ...(presetNormalization?.otfNormalizationConfig?.normParamsConfig || {}),
          },
        },
      },
    });

    setIsNameAndDescriptionStepFailed();
    setIsSelectModelStepFailed();
    setIsCellModelConfigStepFailed();
    setIsSelectTsmModelStepFailed();
    setIsRoiMaskStepFailed();
    setIsAdvancedStepFailed();
  };

  /** Saves the current form values as a `run_inference` job preset named `name`. */
  const savePreset = (name: string) => {
    const values = getValues();
    const { slideStainType } = values;

    enqueueSnackbar({
      variant: 'success',
      message: renderProgressMessage('Saving preset...'),
      key: SNACK_BAR_KEY_SAVE_PRESET,
      autoHideDuration: null,
    });
    createJobPresetMutation.mutate({
      name,
      stains: slideStainType ? [slideStainType] : null,
      steps: ['run_inference'],
      sourceStudyId: queryParams.filters?.studyId,
      presetJson: { ...values },
    });
  };

  /**
   * Validates `objectToValidate` against the schema of `step` and records the outcome
   * under that step's index in `isStepFailed`.
   */
  const checkValidationAndSetIsStepFailed = (step: RunInferenceStep, objectToValidate: Record<string, any>) => {
    const stepIndex = indexOf(runInferenceStepsOrder, step);
    try {
      validationSchema[step].validateSync(objectToValidate);
      setIsStepFailed((prev) => ({ ...prev, [stepIndex]: false }));
    } catch (error) {
      console.error(`Step ${stepIndex + 1} (${step}) failed: ${error}`);
      if (debug) {
        console.error(error);
        console.debug({ objectToValidate, schema: validationSchema[step] });
      }
      setIsStepFailed((prev) => ({ ...prev, [stepIndex]: true }));
    }
  };

  const setIsNameAndDescriptionStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.NameAndDescription, {
      jobName: watch('jobName'),
      jobDescription: watch('jobDescription'),
      hmName: watch('hmName'),
    });
  };

  const setIsSelectModelStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.SelectModel, {
      slideStainType: watch('slideStainType'),
      model: watch('model'),
      modelType: watch('modelType'),
    });
  };

  const setIsCellModelConfigStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.CellModelConfig, {
      // The schema only requires dedupValue `.when('modelType', ...)` is a cell model,
      // so the model type has to be part of the validated object.
      modelType: watch('modelType'),
      dedupValue: watch('dedupValue'),
      useDynamicCellDetection: watch('useDynamicCellDetection'),
      dynamicCellDetectionConfig: watch('dynamicCellDetectionConfig'),
    });
  };

  const setIsSelectTsmModelStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.SelectTsmModel, {
      modelType: watch('modelType'),
      tsmModel: watch('tsmModel'),
    });
  };

  const setIsRoiMaskStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.RoiMask, {
      roiMask: watch('roiMask'),
      assignmentIds: watch('assignmentIds'),
      classNames: watch('classNames'),
    });
  };

  const setIsAdvancedStepFailed = () => {
    checkValidationAndSetIsStepFailed(RunInferenceStep.Advanced, {
      skipRunExistingArtifacts: watch('skipRunExistingArtifacts'),
      clearMlMachineType: watch('clearMlMachineType'),
      inferenceVmsLimit: watch('inferenceVmsLimit'),
      branch: watch('branch'),
      normalizationConfig: watch('normalizationConfig'),
    });
  };

  /**
   * Fills the form from an existing job, or resets it (keeping the stain type) when the
   * selection is cleared.
   */
  const onSelectedJobParamChange = (newValue: InferenceJob) => {
    setSelectedJobId(newValue?.id);
    setLoadedFrom(newValue?.name || newValue?.id || '');
    if (isEmpty(newValue)) {
      reset({
        jobName: '',
        jobDescription: '',
        hmName: '',
        slideStainType: getValues('slideStainType'),
        model: '',
        tsmModel: '',
        modelType: '',
        dedupValue: defaultValues.dedupValue,
        roiMask: defaultValues.roiMask,
        assignmentIds: [],
        classNames: [],
        clearMlMachineType: defaultValues.clearMlMachineType,
        inferenceVmsLimit: defaultValues.inferenceVmsLimit,
        branch: defaultValues.branch,
        skipRunExistingArtifacts: defaultValues.skipRunExistingArtifacts,
        useDynamicCellDetection: defaultValues.useDynamicCellDetection,
        dynamicCellDetectionConfig: defaultValues.dynamicCellDetectionConfig,
        normalizationConfig: defaultValues.normalizationConfig,
        channelsToExtract: defaultValues.channelsToExtract,
      });
      return;
    }

    const params = newValue.params;
    reset({
      jobName: newValue.name,
      jobDescription: newValue.description,
      hmName: first(params?.heatmapNames),
      slideStainType: getValues('slideStainType'),
      model: first(params?.modelPaths),
      tsmModel: params?.tissueSegmentationModelOverride,
      modelType: params?.runType,
      roiMask: Boolean(params?.tissueMaskFromAnnotations) || defaultValues.roiMask,
      assignmentIds: params?.assignmentIds,
      classNames: params?.classNames,
      clearMlMachineType: params?.clearmlMachineType || defaultValues.clearMlMachineType,
      // `??` rather than `||`: a stored false / 0 is a real value and must not be
      // replaced by the (truthy / -1) default.
      inferenceVmsLimit: params?.inferenceVmsLimit ?? defaultValues.inferenceVmsLimit,
      branch: params?.branchName || defaultValues.branch,
      skipRunExistingArtifacts: params?.skipRunExistingArtifacts ?? defaultValues.skipRunExistingArtifacts,
      // Number() of a missing or non-numeric value is NaN, which falls back to the default.
      dedupValue: Number(params?.dedupValue) || defaultValues.dedupValue,
      useDynamicCellDetection: params?.useDynamicCellDetection ?? defaultValues.useDynamicCellDetection,
      dynamicCellDetectionConfig: params?.dynamicCellDetectionConfig || defaultValues.dynamicCellDetectionConfig,
      normalizationConfig: params?.normalizationConfig || defaultValues.normalizationConfig,
      channelsToExtract: params?.channelsToExtract || defaultValues.channelsToExtract,
    });

    setIsNameAndDescriptionStepFailed();
    setIsSelectModelStepFailed();
    setIsCellModelConfigStepFailed();
    setIsSelectTsmModelStepFailed();
    setIsRoiMaskStepFailed();
    setIsAdvancedStepFailed();
  };

  const stepsMap = {
    [RunInferenceStep.LoadParams]: {
      label: RunInferenceStep.LoadParams,
      subLabel: loadedFrom,
      content: (
        <LoadParamsStep
          currentPreset={currentPreset}
          onSelectPreset={selectPreset}
          selectedJobId={selectedJobId}
          onSelectJob={onSelectedJobParamChange}
        />
      ),
    },
    [RunInferenceStep.NameAndDescription]: {
      label: RunInferenceStep.NameAndDescription,
      subLabel: watch('jobName'),
      content: (
        <Grid item container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="jobName"
              render={({ field: { onChange } }) => (
                <TextField
                  label="Job Name"
                  {...register('jobName')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['jobName'])}
                  helperText={humanize(errors['jobName']?.message)}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="jobDescription"
              render={({ field: { onChange } }) => (
                <TextField
                  label="Job Description"
                  {...register('jobDescription')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['jobDescription'])}
                  helperText={humanize(errors['jobDescription']?.message)}
                  multiline
                  minRows={4}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="hmName"
              render={({ field: { onChange } }) => (
                <TextField
                  label="HM Name"
                  {...register('hmName')}
                  onChange={onChange}
                  placeholder="Type Here"
                  error={Boolean(errors['hmName'])}
                  helperText={humanize(errors['hmName']?.message)}
                  required
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsNameAndDescriptionStepFailed();
      },
    },
    [RunInferenceStep.SelectModel]: {
      label: RunInferenceStep.SelectModel,
      subLabel: watch('model'),
      content: (
        <SelectModelStep<IFormValues, 'model', 'modelType', 'slideStainType'>
          jobId={jobId}
          casesParams={casesParams}
          debug={debug}
          onSelectModelUrl={setIsSelectModelStepFailed}
          onSelectStainType={() => {
            if (!isEmpty(jobId)) {
              setIsSelectModelStepFailed();
            }
          }}
          onSelectModelType={() => {
            setIsSelectModelStepFailed();
            setIsSelectTsmModelStepFailed();
          }}
          formMethods={formMethods as any}
          modelTypeField="modelType"
          modelUrlField="model"
          slideStainTypeField="slideStainType"
        />
      ),
      onNextOrBackClick: () => {
        setIsSelectModelStepFailed();
        setIsCellModelConfigStepFailed();
      },
    },
    [RunInferenceStep.CellModelConfig]: {
      label: RunInferenceStep.CellModelConfig,
      optional: true,
      skip: watch('modelType') !== modelTypeCell.value,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="dedupValue"
              render={({ field: { onChange } }) => (
                <TextField
                  disabled={watch('modelType') !== modelTypeCell.value}
                  label="Dedup Value"
                  type="number"
                  {...register('dedupValue')}
                  onChange={(event) => {
                    onChange(event);
                    setIsCellModelConfigStepFailed();
                  }}
                  placeholder="Type Here"
                  error={Boolean(errors['dedupValue'])}
                  helperText={
                    errors['dedupValue']?.message
                      ? humanize(errors['dedupValue']?.message)
                      : watch('modelType') !== modelTypeCell.value && 'Only for cell model'
                  }
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="useDynamicCellDetection"
              render={({ field: { onChange, value } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      {...register('useDynamicCellDetection')}
                      checked={value}
                      onChange={(event) => {
                        onChange(event);
                        setIsCellModelConfigStepFailed();
                      }}
                    />
                  }
                  label={
                    <Typography>
                      Use Dynamic Cell Detection{' '}
                      <Typography variant="caption">
                        (If you run with dev, this param will not have an effect, and this selection will be taken from
                        dynamic presets param below)
                      </Typography>
                    </Typography>
                  }
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="dynamicCellDetectionConfig"
              render={({ field: { onChange, value } }) => (
                <LabelledDropdown
                  label="Dynamic Cell Detection Config"
                  options={dynamicCellDetectionConfigOptions}
                  value={value || ''}
                  onOptionSelected={(optionValue) => {
                    onChange(optionValue);
                    setIsCellModelConfigStepFailed();
                  }}
                  error={Boolean(errors['dynamicCellDetectionConfig'])}
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsCellModelConfigStepFailed();
      },
    },
    [RunInferenceStep.SelectTsmModel]: {
      label: RunInferenceStep.SelectTsmModel,
      subLabel: watch('tsmModel'),
      optional: Boolean(watch('modelType')) && watch('modelType') === modelTypeDefect.value,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item>
            <Controller
              control={control}
              name="tsmModel"
              render={({ field: { onChange } }) => (
                <TextField
                  value={watch('tsmModel') || ''}
                  label="TSM Model Url"
                  {...register('tsmModel')}
                  onChange={(event) => {
                    onChange(event);
                    setIsSelectTsmModelStepFailed();
                  }}
                  placeholder="Choose tsm model from the table or write here the artifact url"
                  error={Boolean(errors['tsmModel'])}
                  helperText={humanize(errors['tsmModel']?.message)}
                  required={watch('modelType') !== modelTypeDefect.value}
                />
              )}
            />
          </Grid>
          <Grid item>
            <ModelsTable
              modelType={modelTypeTsm.apiModelValue}
              onSelect={onSelectTsmModel}
              selectedModelUrl={watch('tsmModel')}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsSelectTsmModelStepFailed();
      },
    },
    [RunInferenceStep.RoiMask]: {
      label: RunInferenceStep.RoiMask,
      optional: true,
      content: (
        <Grid container direction="column" spacing={2}>
          <Grid item xs={6}>
            <Controller
              control={control}
              name="roiMask"
              render={({ field: { onChange, value } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      {...register('roiMask')}
                      // Controlled, so the box reflects values loaded from a job or preset via reset()
                      checked={Boolean(value)}
                      onChange={(event) => {
                        onChange(event);
                        setIsRoiMaskStepFailed();
                      }}
                    />
                  }
                  label="ROI Mask"
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="assignmentIds"
              render={({ field: { onChange, value } }) => (
                <AnnotationAssignmentAutocomplete
                  casesParams={casesParams}
                  slideStainType={watch('slideStainType')}
                  multiple
                  disabled={isEmpty(watch('slideStainType')) || !watch('roiMask')}
                  limitTags={1}
                  selectedValue={map(value, Number)}
                  onChange={(event, newValue) => {
                    onChange(map(newValue, 'annotationAssignmentId'));
                    setIsRoiMaskStepFailed();
                  }}
                  textFieldProps={{
                    label: 'Assignments',
                    error: Boolean(errors['assignmentIds']),
                    helperText: !watch('roiMask')
                      ? 'Only when ROI mask is checked'
                      : isEmpty(watch('slideStainType'))
                        ? 'Need to choose stain type'
                        : humanize(errors['assignmentIds']?.message),
                  }}
                />
              )}
            />
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="classNames"
              render={({ field: { onChange, value } }) => (
                <LabelledDropdown
                  multiple={true}
                  disabled={!watch('roiMask')}
                  label="Class Names"
                  options={classNamesOptions}
                  value={value || []}
                  onOptionSelected={(optionValue) => {
                    onChange(optionValue);
                    setIsRoiMaskStepFailed();
                  }}
                  error={Boolean(errors['classNames'])}
                  helperText={
                    !watch('roiMask') ? 'Only when ROI mask is checked' : humanize(errors['classNames']?.message)
                  }
                />
              )}
            />
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsRoiMaskStepFailed();
      },
    },
    [RunInferenceStep.Advanced]: {
      label: RunInferenceStep.Advanced,
      optional: true,
      content: (
        <Grid item container direction="column" spacing={2}>
          <Grid item container direction="column" spacing={2}>
            <Grid item>
              {/* select channelsToExtract (choose between dapi and "hoechst")  */}
              <Controller
                control={control}
                name="channelsToExtract"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    label="Channel To Extract"
                    options={channelsToExtractOptions}
                    // in the defects case, there is only one channel to extract
                    value={first(value) ?? ''}
                    onOptionSelected={(optionValue) => {
                      onChange([optionValue]);
                    }}
                    error={Boolean(errors['channelsToExtract'])}
                    helperText={humanize(errors['channelsToExtract']?.message)}
                  />
                )}
              />
            </Grid>
            <Grid item>
              <Controller
                control={control}
                name="normalizationConfig.loadParamsFromDb"
                render={({ field: { onChange, value } }) => (
                  <FormControlLabel
                    control={
                      <Checkbox
                        checked={!!value}
                        onChange={(event) => {
                          const isChecked = event.target.checked;
                          onChange(isChecked);
                          // On-the-fly normalization is active exactly when DB params are not used.
                          setValue('normalizationConfig.otfNormalizationConfig.active', !isChecked, {
                            shouldValidate: true,
                            shouldDirty: true,
                          });
                        }}
                      />
                    }
                    label="Use pre-normalized values (min_c, max_c)"
                  />
                )}
              />
            </Grid>
            {!watch('normalizationConfig.loadParamsFromDb') && (
              <Grid item>
                <NormalizationParamsForm
                  normalizationConfig={watch('normalizationConfig.otfNormalizationConfig.normParamsConfig') || {}}
                  setNormalizationConfig={onNormalizationConfigChange}
                />
              </Grid>
            )}
          </Grid>
          <Grid item>
            <Controller
              control={control}
              name="skipRunExistingArtifacts"
              render={({ field: { onChange, value } }) => (
                <FormControlLabel
                  control={
                    <Checkbox
                      {...register('skipRunExistingArtifacts')}
                      // Controlled (instead of defaultChecked) so values loaded via reset() are shown
                      checked={Boolean(value)}
                      onChange={(event) => {
                        onChange(event);
                      }}
                    />
                  }
                  label={
                    <Typography>
                      Skip Existing Artifacts{' '}
                      <Typography variant="caption">
                        (Skip run if artifact exists - if False, an artifact will be created even if already exists )
                      </Typography>
                    </Typography>
                  }
                />
              )}
            />
          </Grid>
          <Grid container item spacing={2}>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="clearMlMachineType"
                render={({ field: { onChange, value } }) => (
                  <LabelledDropdown
                    label="clearMl Machine Type"
                    options={clearMlMachineTypeOptions}
                    value={value}
                    onOptionSelected={(optionValue) => {
                      onChange(optionValue);
                    }}
                    error={Boolean(errors['clearMlMachineType'])}
                    helperText={humanize(errors['clearMlMachineType']?.message)}
                    required
                  />
                )}
              />
            </Grid>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="inferenceVmsLimit"
                render={({ field: { onChange } }) => (
                  <TextField
                    label="Inference Vms Limit"
                    type="number"
                    InputProps={{ inputProps: { max: 100, min: -1 } }}
                    {...register('inferenceVmsLimit')}
                    onChange={onChange}
                    placeholder="Type Here"
                    error={Boolean(errors['inferenceVmsLimit'])}
                    helperText={humanize(errors['inferenceVmsLimit']?.message)}
                  />
                )}
              />
            </Grid>
            <Grid item xs={4}>
              <Controller
                control={control}
                name="branch"
                render={({ field: { onChange } }) => (
                  <TextField
                    label="Branch"
                    {...register('branch')}
                    onChange={onChange}
                    placeholder="Type Here"
                    error={Boolean(errors['branch'])}
                    helperText={humanize(errors['branch']?.message)}
                    required
                  />
                )}
              />
            </Grid>
          </Grid>
        </Grid>
      ),
      onNextOrBackClick: () => {
        setIsAdvancedStepFailed();
      },
    },
  };

  const steps = map(runInferenceStepsOrder, (stepId) => stepsMap[stepId]);

  return (
    <JobWithRebuild jobId={jobId} onSelectedJobParamChange={onSelectedJobParamChange}>
      <PlatformStepper
        handleSubmit={handleSubmit(onSubmit)}
        steps={steps}
        setActiveStepForValidation={setActiveStep}
        isStepFailed={isStepFailed}
        handleSaveAsPreset={savePreset}
      />
    </JobWithRebuild>
  );
};

/** Settings for computing normalization values on the fly instead of loading them from the DB. */
interface OnTheFlyNormalizationConfig {
  /**
   * Whether on-the-fly normalization is used. The Advanced step sets it to the inverse of
   * `loadParamsFromDb` whenever that checkbox is toggled.
   */
  active: boolean;
  /** Per-channel parameters, edited through `NormalizationParamsForm`. */
  normParamsConfig?: NormalizationConfig | null;
  // NOTE(review): the two fields below are not edited by this form; presumably consumed by the platform — confirm.
  channelName?: string | null;
  histogramRoot?: string | null;
}

/** Normalization section of the inference form values. */
interface InferenceNormalizationConfig {
  /** Master switch for normalization. */
  active: boolean;
  /** If true, loads the per_slide_params from the DB. */
  loadParamsFromDb: boolean;
  /** Parameters to find normalization values on the fly. If null, use normalization from DB. */
  otfNormalizationConfig?: OnTheFlyNormalizationConfig | null;
}

/** Values of the run-inference form, grouped by the wizard step that edits them. */
export interface IFormValues {
  // Name and Description
  jobName: string;
  jobDescription: string;
  hmName: string;

  // Select Model
  slideStainType: string;
  model: string;
  modelType: string;

  // Select TSM Model
  tsmModel: string;

  // Cell Model Config (only relevant for cell models)
  dedupValue: number;
  useDynamicCellDetection: boolean;
  dynamicCellDetectionConfig: string;

  // ROI Mask
  roiMask: boolean;
  assignmentIds: string[];
  classNames: string[];

  // Advanced
  channelsToExtract?: string[];
  normalizationConfig?: InferenceNormalizationConfig;
  skipRunExistingArtifacts: boolean;
  clearMlMachineType: string;
  inferenceVmsLimit: number;
  branch: string;
}

/**
 * yup schema per wizard step.
 *
 * The form resolver uses the schema of the active step. `checkValidationAndSetIsStepFailed`
 * validates a hand-picked subset of the form values against a step's schema. Conditional
 * rules (`when`) read sibling values from the validated object, so the field they depend on
 * (`modelType`, `roiMask`) must be included in that object even when the schema does not
 * declare it.
 */
const validationSchema = {
  [RunInferenceStep.LoadParams]: yup.object({}),
  [RunInferenceStep.NameAndDescription]: yup.object({
    jobName: yup.string(),
    jobDescription: yup.string(),
    hmName: yup.string().required(),
  }),
  [RunInferenceStep.SelectModel]: yup.object({
    slideStainType: yup.string().required('Stain Type is a required field'),
    model: yup.string().required(),
    modelType: yup.string().required(),
  }),
  [RunInferenceStep.CellModelConfig]: yup.object({
    // `modelType` is not part of this step; it is read from the validated object.
    dedupValue: yup.number().when('modelType', {
      is: modelTypeCell.value,
      then: yup.number().required('Dedup Value is a required field'),
    }),
    useDynamicCellDetection: yup.boolean(),
    dynamicCellDetectionConfig: yup.string(),
  }),
  [RunInferenceStep.SelectTsmModel]: yup.object({
    // A TSM model is optional for defect models and required for every other model type.
    tsmModel: yup.string().when('modelType', {
      is: modelTypeDefect.value,
      then: yup.string().notRequired().nullable(),
      otherwise: yup.string().required(),
    }),
  }),
  [RunInferenceStep.RoiMask]: yup.object({
    roiMask: yup.boolean(),
    // Since yup 0.30, `array().required()` accepts an empty array, and the form starts with `[]`;
    // `min(1)` is what actually enforces "required when roiMask is checked".
    assignmentIds: yup.array().when('roiMask', {
      is: true,
      then: yup.array().required('Required when roiMask is checked').min(1, 'Required when roiMask is checked'),
    }),
    classNames: yup.array().when('roiMask', {
      is: true,
      then: yup.array().required('Required when roiMask is checked').min(1, 'Required when roiMask is checked'),
    }),
  }),
  [RunInferenceStep.Advanced]: yup.object({
    skipRunExistingArtifacts: yup.boolean(),
    clearMlMachineType: yup.string().required(),
    inferenceVmsLimit: yup.number().required(),
    branch: yup.string().required(),
    normalizationConfig: yup.object({
      active: yup.boolean().required(),
      loadParamsFromDb: yup.boolean().required(),
      otfNormalizationConfig: yup.object({
        active: yup.boolean().required(),
        normParamsConfig: normalizationSchema,
      }),
    }),
  }),
};
