import React, { useEffect, useMemo, useState } from 'react'

import { History } from 'history'
import { uniq } from 'lodash'
import qs from 'qs'
import { useDispatch } from 'react-redux'
import { useQuery } from 'react-redux-query'
import { useHistory } from 'react-router-dom'

import { getterKeys, query, service } from 'api'
import { Button, PrismButtonProps } from 'components/Button/Button'
import ImgFallback from 'components/Img/ImgFallback'
import PrismCheckbox from 'components/PrismCheckbox/PrismCheckbox'
import { PrismElementaryCube, PrismSharedToolIcon } from 'components/prismIcons'
import { error } from 'components/PrismMessage/PrismMessage'
import { Modal, modal, ModalHeaderWithIcon } from 'components/PrismModal/PrismModal'
import PrismOverflowTooltip from 'components/PrismOverflowTooltip/PrismOverflowTooltip'
import { PrismResultButton } from 'components/PrismResultButton/PrismResultButton'
import { useAllToolLabels, useQueryParams } from 'hooks'
import { CaptionText } from 'pages/RoutineOverview/RoutineSettings/ItemCreation'
import { Component, Experiment, ToolLabel, ToolResultCount, ToolSpecificationName } from 'types'
import {
  findSingleToolLabelFromPartialData,
  getDisplaySeverity,
  getExperimentState,
  getLabelName,
  getToolLabelImagesToShow,
  sortByValueAndSeverity,
} from 'utils'
import {
  MINIMUM_LABELS_COUNT_FOR_DEFECT_AND_MATCH_TRAINING,
  MINIMUM_NORMAL_IMAGES_FOR_GAR_AND_ANOMALY_TRAINING,
} from 'utils/constants'
import { GOOD_NORMAL_LABEL, TEST_SET_LABEL } from 'utils/labels'

import Styles from './TrainButton.module.scss'

// Tooltip messages shown on the training modal's disabled OK button when the current
// label selection does not meet the tool type's training requirements.

// Match classifier: fewer than the minimum number of labels are selected.
const LABEL_MINIMUM_ERROR_MESSAGE = 'You need two labels in order to train'
// Defect classifier: the Normal label is missing, or no other label is selected with it.
const MISSING_NORMAL_AND_ANOTHER_LABEL_ERROR_MESSAGE = 'You need the Normal label and one other label in order to train'
// GAR / Anomaly: the Normal label is missing or does not have enough images.
const MISSING_NORMAL_LABEL_ERROR_MESSAGE = `You need the Normal label with ${MINIMUM_NORMAL_IMAGES_FOR_GAR_AND_ANOMALY_TRAINING} images in order to train`

/**
 * Renders the button that kicks off training for a tool parent, or cancels the training
 * that is currently in progress.
 *
 * When the latest experiment is in progress, the button reads "Cancel Train". After a
 * confirmation modal, it patches that experiment to `canceled` and refetches the tool parent.
 * Otherwise, clicking it opens the training modal by setting the `showTrainingModal`
 * query param.
 *
 * @param readOnly - Whether read only mode is active; disables the button.
 * @param latestExperiment - The latest experiment for the tool parent, if any. Used to decide
 *   whether training is in progress and which experiment to cancel.
 * @param toolParentId - Id of the tool parent; refetched after a successful cancel.
 * @param buttonType - Optional visual type forwarded to the underlying `Button`.
 */
const TrainButton = ({
  readOnly,
  latestExperiment,
  toolParentId,
  buttonType,
}: {
  readOnly: boolean
  latestExperiment: Experiment | undefined
  toolParentId: string
  buttonType?: PrismButtonProps['type']
}) => {
  const dispatch = useDispatch()
  const history = useHistory()
  const [params] = useQueryParams()

  const trainingMostRecentModel = latestExperiment && getExperimentState(latestExperiment.state) === 'in_progress'

  const refetchToolParent = async () => {
    await query(getterKeys.toolParent(toolParentId), () => service.getToolParent(toolParentId), {
      dispatch,
    })
  }

  const cancelTraining = async (experimentId: string) => {
    const cancelRes = await service.patchExperiment(experimentId, { state: 'canceled' })

    if (cancelRes.type !== 'success') {
      error({ title: 'Something went wrong cancelling the training.' })
      return
    }

    await refetchToolParent()
  }

  const handleTrainButtonClick = () => {
    // Capture the id in a const so its narrowing to `string` survives into the onOk closure.
    const experimentId = latestExperiment?.id

    if (experimentId && trainingMostRecentModel) {
      // We need to cancel the training of the last tool's experiment
      modal.confirm({
        id: 'cancel-training-confirmation',
        header: <ModalHeaderWithIcon title="Cancel Training?" type="warning" />,
        danger: true,
        content: "Are you sure? You'll lose your training progress.",
        okText: 'CANCEL TRAINING',
        cancelText: 'CONTINUE',
        onOk: async close => {
          await cancelTraining(experimentId)
          close()
        },
      })
      return
    }

    handleShowTrainingModal({ history, params })
  }

  return (
    <div className={Styles.trainBtnWrapper}>
      <Button
        data-testid="tool-info-card-train"
        size="small"
        type={buttonType}
        onClick={handleTrainButtonClick}
        className={trainingMostRecentModel ? Styles.trainBtn : ''}
        disabled={readOnly}
      >
        {trainingMostRecentModel && 'Cancel'} Train
      </Button>
    </div>
  )
}

/**
 * Opens the training modal by adding `showTrainingModal=true` to the current URL's query
 * string while preserving the other query params.
 *
 * This was a verbatim copy of `handleShowTrainingModal`. Both names are exported, so this one
 * now delegates to it and the two entry points cannot drift apart.
 *
 * @param history - Router history whose current location is rewritten.
 * @param params - Current query params to carry over alongside `showTrainingModal`.
 */
export const showTrainingModal = ({
  history,
  params,
}: {
  history: History
  params: { [x: string]: string | undefined }
}) => {
  handleShowTrainingModal({ history, params })
}

/**
 * Header row for a list in the training modal. Shows the list title, an optional
 * "(selected of total)" counter, and an "IMAGES" column label.
 *
 * @param total - Number of items in the list; shown as 0 while still undefined (loading).
 * @param count - Number of selected items; shown as 0 while still undefined (loading).
 * @param title - List title, e.g. "LABELS" or "PRODUCTS".
 * @param hideCount - When true, the counter is not rendered.
 */
const TrainingModalListHeader = ({
  total,
  count,
  title,
  hideCount,
}: {
  total?: number
  count?: number
  title: string
  hideCount?: boolean
}) => {
  // Default to 0 so the header never renders "(undefined of undefined)" while the
  // selections and the lists they come from are still loading.
  const counter = hideCount ? null : `(${count ?? 0} of ${total ?? 0})`

  return (
    <div className={Styles.listHeader}>
      <span>
        {title} {counter}
      </span>
      <span>IMAGES</span>
    </div>
  )
}

/**
 * Opens the training modal by setting `showTrainingModal=true` in the URL query string.
 *
 * Existing query params are kept. The current history entry is replaced rather than a new
 * one being pushed.
 *
 * @param history - Router history whose current location is rewritten.
 * @param params - The current query params to carry over.
 */
export const handleShowTrainingModal = ({
  history,
  params,
}: {
  history: History
  params: { [x: string]: string | undefined }
}) => {
  const { pathname } = history.location
  const nextQuery = { ...params, showTrainingModal: true }

  history.replace({ pathname, search: qs.stringify(nextQuery) })
}

/**
 * Modal that lets the user choose which ToolLabels and products (components) to train a
 * tool parent on, and then kicks off the training job.
 *
 * Closing the modal, or a successful kickoff, removes the `showTrainingModal` query param.
 * Presumably the parent renders this modal based on that param; verify against the caller.
 *
 * Selection rules:
 * - Every label and every product that has images starts out selected.
 * - Deselecting products recomputes the per-label image counts. Labels left with no images
 *   are deselected automatically.
 * - If any tool result has no component (a `"null"` key in `countsByComponentId`), filtering
 *   by product is disabled and no `component_ids` are sent with the training request.
 *
 * @param toolParentId - Id of the tool parent to train.
 * @param toolSpecificationName - Tool type; determines the label requirements for training.
 * @param toolParentName - Name shown in the modal header.
 * @param toolIsShared - Whether to show the shared-tool icon in the header.
 * @param totalCountsByLabelId - Image counts per ToolLabel id across all products.
 * @param countsByComponentId - Image counts per component id; a `"null"` key means some
 *   results have no component.
 * @param toolResultsCounts - Image counts per (label, component) pair, used to cross-filter
 *   the label and product lists.
 * @param routineId - Routine to train on. When omitted, the first working-version recipe
 *   routine of the tool parent is fetched and used instead.
 */
export const TrainingModal = ({
  toolParentId,
  toolSpecificationName,
  toolParentName,
  toolIsShared,
  totalCountsByLabelId,
  countsByComponentId,
  toolResultsCounts,
  routineId,
}: {
  toolParentId: string
  toolSpecificationName: ToolSpecificationName | undefined
  toolParentName: string | undefined
  toolIsShared: boolean | undefined
  totalCountsByLabelId: { [labelId: string]: number } | undefined
  countsByComponentId: { [componentId: string]: number } | undefined
  toolResultsCounts: ToolResultCount['results'] | undefined
  routineId: string | undefined
}) => {
  const dispatch = useDispatch()
  const history = useHistory()
  const [params] = useQueryParams()

  // `null` means "not initialized yet". Once the trainable labels are known, all of them get selected.
  const [selectedLabelIds, setSelectedLabelIds] = useState<Set<string> | null>(null)
  const [selectedComponentIds, setSelectedComponentIds] = useState<Set<string>>()

  const handleLabelCheckboxChange = (checked: boolean, id: string) => {
    setSelectedLabelIds(prev => {
      const updated = new Set(prev)

      if (checked) updated.add(id)
      else updated.delete(id)

      return updated
    })
  }

  const handleProductCheckboxChange = (checked: boolean, id: string) => {
    setSelectedComponentIds(prev => {
      const updated = new Set(prev)

      if (checked) updated.add(id)
      else updated.delete(id)

      return updated
    })
  }

  // Only needed when the caller didn't provide a routine; an undefined key skips the query.
  const recipeRoutines = useQuery(
    routineId ? undefined : getterKeys.toolParentRecipeRoutines(`train-modal-${toolParentId}`),
    () => service.getRecipeRoutines({ tool_parent_id: toolParentId, is_working_version: true, page_size: 1 }),
  ).data?.data.results

  const routineIdToUse = useMemo(() => {
    if (routineId) return routineId

    return recipeRoutines?.[0]?.routine.id
  }, [recipeRoutines, routineId])

  const componentIds = useMemo(() => {
    if (!countsByComponentId) return
    return uniq(Object.keys(countsByComponentId))
  }, [countsByComponentId])

  const components = useQuery(
    // Same undefined-key skip as above. Wait until the ids are known, so we never send
    // `id__in: undefined` (presumably an unfiltered request).
    componentIds ? getterKeys.components(`training-modal-${toolParentId}`) : undefined,
    () => service.getComponents({ id__in: componentIds?.join() }),
    // Sort a copy: `Array.prototype.sort` sorts in place and would mutate the memoized
    // `componentIds` array during render.
    { refetchKey: componentIds ? [...componentIds].sort().join() : undefined },
  ).data?.data.results

  // If some componentId value is null, we don't allow filtering by component when
  // selecting Tool Results for training
  const componentFilteringEnabled = !Object.keys(countsByComponentId || {}).includes('null')

  const componentsForTraining = useMemo(() => {
    if (!components || !countsByComponentId) return

    return components.filter(component => !!countsByComponentId?.[component.id])
  }, [countsByComponentId, components])

  const { allToolLabels, defaultLabels } = useAllToolLabels({
    ignoreUnusedLabels: true,
    ignoreDerivativeLabels: true,
    ignoreNonTrainingLabels: true,
  })

  const testSetLabel = useMemo(() => {
    return findSingleToolLabelFromPartialData(defaultLabels, TEST_SET_LABEL)
  }, [defaultLabels])

  // ToolLabels that have at least one image for this tool, sorted for display.
  const toolLabelsForTraining = useMemo(() => {
    if (!allToolLabels || !totalCountsByLabelId) return
    const labelsToReturn: ToolLabel[] = []
    Object.entries(totalCountsByLabelId).forEach(([labelId, count]) => {
      if (!count) return

      const foundLabel = allToolLabels.find(label => label.id === labelId)
      if (foundLabel) labelsToReturn.push(foundLabel)
    })

    labelsToReturn.sort(sortByValueAndSeverity)

    return labelsToReturn
  }, [allToolLabels, totalCountsByLabelId])

  // Select every trainable label once, the first time the list becomes available.
  useEffect(() => {
    if (!toolLabelsForTraining || selectedLabelIds !== null) return

    setSelectedLabelIds(new Set(toolLabelsForTraining.map(toolLabel => toolLabel.id)))
  }, [selectedLabelIds, toolLabelsForTraining])

  // Select every product with images whenever the product list changes.
  useEffect(() => {
    if (!componentsForTraining) return

    setSelectedComponentIds(new Set(componentsForTraining.map(component => component.id)))
  }, [componentsForTraining])

  // Image counts per label, restricted to the selected products (when product filtering is enabled).
  const countsByLabelIdFilteredBySelectedComponents = useMemo(() => {
    if (!toolResultsCounts || !selectedComponentIds || !toolLabelsForTraining) return
    const countsToReturn: { [labelId: string]: number } = {}
    toolResultsCounts.forEach(countObj => {
      const componentId = countObj.component_id
      if (componentFilteringEnabled && componentId && !selectedComponentIds.has(componentId)) return

      const labelId = countObj.active_user_label_set__tool_labels__id

      if (!toolLabelsForTraining.find(toolLabel => toolLabel.id === labelId)) return

      countsToReturn[labelId] ??= 0
      countsToReturn[labelId] += countObj.count
    })

    return countsToReturn
  }, [toolResultsCounts, selectedComponentIds, toolLabelsForTraining, componentFilteringEnabled])

  // Image counts per product, restricted to the selected labels.
  const countsByComponentIdFilteredBySelectedLabels = useMemo(() => {
    if (!toolResultsCounts || !selectedLabelIds) return
    const countsToReturn: { [componentId: string]: number } = {}
    toolResultsCounts.forEach(countObj => {
      const componentId = countObj.component_id
      if (!componentId) return

      const labelId = countObj.active_user_label_set__tool_labels__id

      if (!selectedLabelIds?.has(labelId)) return

      countsToReturn[componentId] ??= 0
      countsToReturn[componentId] += countObj.count
    })

    return countsToReturn
  }, [toolResultsCounts, selectedLabelIds])

  // This effect is in charge of deselecting ToolLabels whose count is 0 after changing the selected Products
  useEffect(() => {
    if (!toolLabelsForTraining || !countsByLabelIdFilteredBySelectedComponents || !selectedLabelIds?.size) return

    toolLabelsForTraining.forEach(toolLabel => {
      if (selectedLabelIds.has(toolLabel.id) && !countsByLabelIdFilteredBySelectedComponents[toolLabel.id]) {
        handleLabelCheckboxChange(false, toolLabel.id)
      }
    })
  }, [countsByLabelIdFilteredBySelectedComponents, selectedLabelIds, toolLabelsForTraining])

  const totalImagesForTraining = useMemo(() => {
    if (!selectedLabelIds) return 0
    let total = 0
    Array.from(selectedLabelIds).forEach(labelId => {
      const labelCount = countsByLabelIdFilteredBySelectedComponents?.[labelId] || 0
      total += labelCount
    })

    return total
  }, [selectedLabelIds, countsByLabelIdFilteredBySelectedComponents])

  const { readyToTrain, errorMessage } = useMemo(() => {
    // For match tool, we need at least 2 labels
    if (toolSpecificationName === 'match-classifier') {
      const minimumLabelsSelected = (selectedLabelIds?.size || 0) >= MINIMUM_LABELS_COUNT_FOR_DEFECT_AND_MATCH_TRAINING

      if (minimumLabelsSelected) return { readyToTrain: true, errorMessage: undefined }

      return { readyToTrain: false, errorMessage: LABEL_MINIMUM_ERROR_MESSAGE }
    }

    const normalLabelId = findSingleToolLabelFromPartialData(toolLabelsForTraining, GOOD_NORMAL_LABEL)?.id
    const normalLabelSelected = normalLabelId && selectedLabelIds?.has(normalLabelId)

    // For defect we need the Normal and at least one other label
    if (toolSpecificationName === 'classifier') {
      if (!normalLabelSelected || (selectedLabelIds?.size || 0) < MINIMUM_LABELS_COUNT_FOR_DEFECT_AND_MATCH_TRAINING) {
        return { readyToTrain: false, errorMessage: MISSING_NORMAL_AND_ANOTHER_LABEL_ERROR_MESSAGE }
      }

      return { readyToTrain: true, errorMessage: undefined }
    }

    // For GAR and Anomaly we only need the Normal label with
    // MINIMUM_NORMAL_IMAGES_FOR_GAR_AND_ANOMALY_TRAINING images. Only count images from the
    // selected products: those are the only ones sent for training, and the same counts
    // feed the "Train on N" button label.
    if (normalLabelSelected) {
      const normalLabelCount = countsByLabelIdFilteredBySelectedComponents?.[normalLabelId] || 0
      if (normalLabelCount >= MINIMUM_NORMAL_IMAGES_FOR_GAR_AND_ANOMALY_TRAINING) {
        return { readyToTrain: true, errorMessage: undefined }
      }
    }
    return { readyToTrain: false, errorMessage: MISSING_NORMAL_LABEL_ERROR_MESSAGE }
  }, [selectedLabelIds, toolSpecificationName, toolLabelsForTraining, countsByLabelIdFilteredBySelectedComponents])

  const refetchToolParent = async () => {
    await query(getterKeys.toolParent(toolParentId), () => service.getToolParent(toolParentId), {
      dispatch,
    })
  }

  const closeTrainingModal = () => {
    history.replace({
      pathname: history.location.pathname,
      search: qs.stringify({ ...params, showTrainingModal: undefined }),
    })
  }

  const handleTrainingKickoff = async (
    labelIds: string[],
    componentIdsToTrain: string[],
    filterByComponent: boolean,
  ) => {
    // The TestSet ToolLabel id must be sent with every training request. Tell the user
    // instead of silently ignoring the click when it can't be found.
    if (!testSetLabel) {
      error({ title: 'Something went wrong initiating the training job.' })
      return
    }

    const res = await service.trainRoutine({
      routine_id: routineIdToUse,
      tool_parent_ids: [toolParentId],
      // We need to send the TestSet ToolLabel id for training
      user_label_ids: [...labelIds, testSetLabel.id],
      component_ids: filterByComponent ? componentIdsToTrain : undefined,
    })

    if (res.type !== 'success') {
      error({ title: 'Something went wrong initiating the training job.' })
      return
    }

    await refetchToolParent()

    closeTrainingModal()
  }

  return (
    <Modal
      id="training-modal"
      size="largeSimpleForm"
      onClose={() => closeTrainingModal()}
      header={
        <div className={Styles.modalTitle}>
          Train {toolParentName} {toolIsShared && <PrismSharedToolIcon className={Styles.sharedIcon} />}{' '}
        </div>
      }
      okText={`Train on ${totalImagesForTraining}`}
      onOk={async () => {
        if (!selectedLabelIds || !selectedComponentIds) return

        // Only send component Ids with ToolResults based on the selected ToolLabels
        const componentIdsWithResults = Array.from(selectedComponentIds).filter(
          componentId => !!countsByComponentIdFilteredBySelectedLabels?.[componentId],
        )

        await handleTrainingKickoff(Array.from(selectedLabelIds), componentIdsWithResults, componentFilteringEnabled)
      }}
      disableSave={!readyToTrain}
      okToolTipProps={readyToTrain ? undefined : { title: errorMessage }}
      modalBodyClassName={Styles.trainModalBody}
      modalFooterClassName={Styles.trainModalFooter}
    >
      <CaptionText
        text="Select images from this tool that you’d like to train with"
        captionClassName={Styles.caption}
        iconClassName={Styles.captionIcon}
        className={Styles.captionWrapper}
      />

      <section className={Styles.trainSectionContainer}>
        <TrainingModalListHeader title="LABELS" count={selectedLabelIds?.size} total={toolLabelsForTraining?.length} />

        {toolLabelsForTraining?.map(toolLabel => {
          const labelImage = toolLabel.kind !== 'default' && getToolLabelImagesToShow(toolLabel).find(img => !!img)
          const labelCount = countsByLabelIdFilteredBySelectedComponents?.[toolLabel.id]
          const count = labelCount || 0
          return (
            <TrainingModalListItem
              key={toolLabel.id}
              id={toolLabel.id}
              title={
                <PrismResultButton
                  severity={getDisplaySeverity(toolLabel)}
                  value={getLabelName(toolLabel)}
                  type="noFill"
                  size="small"
                  className={Styles.trainModalSeverity}
                />
              }
              image={
                labelImage ? (
                  <ImgFallback src={labelImage} className={Styles.toolLabelImage} loaderType="skeleton" />
                ) : undefined
              }
              count={count}
              checked={selectedLabelIds?.has(toolLabel.id)}
              onChange={handleLabelCheckboxChange}
              disabled={!labelCount}
            />
          )
        })}
      </section>

      <section className={Styles.trainSectionContainer}>
        <TrainingModalListHeader
          title="PRODUCTS"
          count={selectedComponentIds?.size}
          total={componentsForTraining?.length}
          hideCount={!componentFilteringEnabled}
        />

        {componentFilteringEnabled &&
          componentsForTraining?.map(component => {
            const productImage = getProductImage(component)
            const productCounts = countsByComponentIdFilteredBySelectedLabels?.[component.id]
            const count = productCounts || 0
            return (
              <TrainingModalListItem
                key={component.id}
                id={component.id}
                title={<PrismOverflowTooltip className={Styles.productName} content={component.name} />}
                image={
                  productImage ? (
                    <ImgFallback src={productImage} className={Styles.toolLabelImage} loaderType="skeleton" />
                  ) : undefined
                }
                count={count}
                checked={selectedComponentIds?.has(component.id)}
                onChange={handleProductCheckboxChange}
              />
            )
          })}

        {!componentFilteringEnabled && (
          <CaptionText
            type="info"
            text="This tool does not support filtering by product"
            captionClassName={Styles.caption}
            iconClassName={Styles.captionIcon}
            className={Styles.captionWrapper}
          />
        )}
      </section>
    </Modal>
  )
}

/**
 * A single selectable row in the training modal. It is a checkbox whose label shows a
 * thumbnail, a title and an image count. When no image is given, a cube placeholder is
 * shown instead of the thumbnail.
 *
 * @param title - Main content of the row, e.g. a label badge or a product name.
 * @param image - Optional thumbnail; a `PrismElementaryCube` icon is rendered when falsy.
 * @param count - Image count rendered at the end of the row.
 * @param id - Identifier passed back to `onChange` together with the new checked state.
 * @param checked - Whether the checkbox is checked.
 * @param onChange - Called with `(checked, id)` every time the checkbox is toggled.
 * @param disabled - Whether toggling is disabled.
 */
const TrainingModalListItem = ({
  title,
  image,
  count,
  id,
  checked,
  onChange,
  disabled,
}: {
  title: React.ReactNode
  image?: React.ReactNode
  count?: React.ReactNode
  id: string
  checked?: boolean
  onChange: (checked: boolean, id: string) => void
  disabled?: boolean
}) => {
  const thumbnail = image || <PrismElementaryCube />

  const rowLabel = (
    <>
      <figure className={Styles.labelImageContainer}>{thumbnail}</figure>

      {title}
      <div className={Styles.labelImageCount}>{count}</div>
    </>
  )

  return (
    <PrismCheckbox
      className={Styles.labelListItem}
      labelClassName={Styles.labelWrapper}
      checked={checked}
      disabled={disabled}
      onChange={event => onChange(event.target.checked, id)}
      label={rowLabel}
    />
  )
}

/**
 * Picks the image to display for a product (component).
 *
 * Preference order:
 * 1. The product's own `image`.
 * 2. The first fallback image with a thumbnail, from the product's first recipe parent.
 * 3. The first fallback image with a full-size `image`, from the same recipe parent.
 *
 * @param product - The component to find an image for.
 * @returns The chosen image, or `undefined` when none is available.
 */
const getProductImage = (product: Component) => {
  if (product.image) return product.image

  const firstRecipeParent = product.recipe_parents[0]
  if (!firstRecipeParent) return undefined

  const fallbackImages = firstRecipeParent.fallback_images
  const withThumbnail = fallbackImages.find(image => !!image.image_thumbnail)
  if (withThumbnail?.image_thumbnail) return withThumbnail.image_thumbnail

  return fallbackImages.find(image => !!image.image)?.image
}

// The Train / Cancel Train button is this module's default export; the modal and helpers are named exports.
export default TrainButton
