import React from 'react';
import classNames from 'classnames/bind';
import { Controller, useFormContext } from 'react-hook-form';

import { Field } from 'components/Field/Field';
import { FloatInput } from 'components/NumberInput/NumberInput';
import { MultiInput } from 'components/MultiInput/MultiInput';
import { Radio, RadioOption } from 'components/Radio/Radio';

import {
  FIELD_LABEL_WIDTH,
  Step2FieldValues,
  areAllMultiInputValuesValid,
} from '../Step2';

import styles from './../MultiStepForm.module.scss';

// Class-name helper bound to the MultiStepForm SCSS module: `c('input')`
// looks the key up in `styles` before building the final className string.
const c = classNames.bind(styles);

/**
 * Form section for `inference_config.net_scale_factor` and, when requested,
 * `inference_config.normalization_std`.
 *
 * The net scale factor can be typed into the float input or picked from the
 * radio presets (1.0, 1/256, or "Use standard deviation values"). The last
 * preset stores `null`; while the stored value is `null` the per-channel
 * standard deviation field is rendered and marked as required instead of the
 * net scale factor.
 *
 * Must be rendered inside a react-hook-form `FormProvider`, since it reads the
 * form through `useFormContext`.
 */
export function NetScaleFactorInput() {
  const {
    control,
    formState: { errors },
    setValue,
    watch,
  } = useFormContext<Step2FieldValues>();

  const netScaleFactor = watch('inference_config.net_scale_factor');
  // `null` is the sentinel written by the "Use standard deviation values"
  // radio option.
  const shouldUseStdDeviation = netScaleFactor === null;

  return (
    <>
      <Field
        id="net_scale_factor"
        label="Net scale factor"
        labelWidth={FIELD_LABEL_WIDTH}
        labelVerticalAlign="flex-start"
        error={errors.inference_config?.net_scale_factor}
        required={!shouldUseStdDeviation}
      >
        <Controller
          name="inference_config.net_scale_factor"
          control={control}
          defaultValue={1.0}
          render={({ field: { value, onChange } }) => {
            // Shared by the float input and the radio presets so both apply
            // the same rule: once a finite scale factor is chosen, any
            // previously entered standard deviation values are cleared.
            const handleNetScaleFactorChange = (nextValue: unknown) => {
              onChange(nextValue);

              if (Number.isFinite(nextValue)) {
                setValue('inference_config.normalization_std', undefined);
              }
            };

            return (
              <div className={c('input')}>
                <FloatInput
                  id="net_scale_factor"
                  value={value ?? ''}
                  onChange={handleNetScaleFactorChange}
                  step={0.01}
                  decimalScale={20}
                  min={0}
                />

                <Radio
                  id="net_scale_factor_options"
                  value={value}
                  onChange={handleNetScaleFactorChange}
                >
                  <RadioOption value={1}>1.0</RadioOption>
                  <RadioOption value={1 / 256}>1/256</RadioOption>
                  <RadioOption value={null}>
                    Use standard deviation values
                  </RadioOption>
                </Radio>
              </div>
            );
          }}
          rules={{
            min: {
              value: 0,
              message: 'Net scale factor must be at least 0.',
            },
            max: {
              value: 1,
              message: 'Net scale factor must be less than or equal to 1.',
            },
            validate(value) {
              // `null` means "use standard deviation values" and is valid
              // here; the standard deviation field carries its own rules.
              if (value === null) {
                return true;
              }

              if (!Number.isFinite(value)) {
                return 'Please enter or select a net scale factor.';
              }

              // Explicit pass instead of falling through with `undefined`.
              return true;
            },
          }}
        />
      </Field>

      {shouldUseStdDeviation && (
        <Field
          id="normalization_std_0"
          label="Standard deviation values"
          info="Per-channel values used to normalize the images before being passed to the model."
          labelWidth={FIELD_LABEL_WIDTH}
          labelVerticalAlign="flex-start"
          error={errors.inference_config?.normalization_std}
          required
        >
          <Controller
            name="inference_config.normalization_std"
            control={control}
            render={({ field: { value, onChange } }) => (
              <MultiInput
                id="normalization_std"
                length={3}
                type="number"
                value={value}
                defaultValue={[0, 0, 0]}
                onChange={onChange}
                decimalScale={4}
                step={0.01}
                min={0}
                max={1}
              />
            )}
            rules={{
              required:
                'Please enter standard deviation values for each channel.',
              validate: areAllMultiInputValuesValid,
            }}
          />
        </Field>
      )}
    </>
  );
}
