import React, { Key, ReactNode, useCallback, useEffect, useMemo, useState } from 'react'

import { Table } from 'antd'
import { ColumnsType } from 'antd/lib/table'

import GenericBlankStateMessage from 'components/BlankStates/GenericBlankStateMessage'
import ImgFallback from 'components/Img/ImgFallback'
import { PrismElementaryCube, PrismSearchIcon } from 'components/prismIcons'
import { PrismLoader } from 'components/PrismLoaders/PrismLoaders'
import { LabelButtonSeverity, PrismResultButton } from 'components/PrismResultButton/PrismResultButton'
import { useDefaultToolLabels, useQueryParams } from 'hooks'
import {
  Component,
  ConnectionStatus,
  Threshold,
  Tool,
  ToolLabel,
  ToolLabelSeverity,
  TrainingMetrics,
  TrainingReportTab,
  TrainingResultFlat,
} from 'types'
import {
  calculatePercentage,
  calculateTrainingResultPrediction,
  extractLabelIdFromCalculatedLabels,
  getToolLabelImagesToShow,
  sleep,
  sortByValueAndSeverity,
} from 'utils'

import { MemoLabelContainer } from './LabelContainer'
import { LiveBatchCarousels } from './LiveBatchCarousels'
import { ThresholdByRoutineParentId } from './TrainingReport'
import Styles from './TrainingReport.module.scss'

/** Key of a samples-table row, built as `${labelId}_${componentId}`. */
type RowId = `${string}_${string}`

/** One row of the samples table, describing a single (label, component) pair. */
interface RowData {
  // Row key; also rendered as the DOM `id` of the table row so it can be scrolled into view.
  key: RowId
  // Data shown in the "Label" column.
  label: {
    image?: string[]
    value: string
    severity: ToolLabelSeverity
    id: string
  }
  // Number of images for this label/component pair, taken from the training metrics `total`.
  images: number
  // Component name shown in the "Product" column, or '--' when the component could not be found.
  product: string
  // Component id, or '' when the component could not be found.
  productId: string
  // Share of correct predictions, rendered with one decimal followed by '%'.
  accuracy: number
  // Ids of the training results shown in this row's carousel; undefined while threshold/tool/default labels
  // are not yet available.
  trainingResultsIds?: string[]
}

/** Props of {@link TrainingReportSamples}. */
type Props = {
  // Part of each carousel's key, so carousels are keyed per tab.
  trainingReportTab: TrainingReportTab
  // Needed to compute predictions and to render the live-batch carousels.
  tool?: Tool
  // When true, the live-batch carousels are shown instead of the samples table.
  showCurrentBatchCarousels: boolean
  // Neither the live carousels nor the table are rendered while 'offline'.
  connectionStatus: ConnectionStatus
  // Threshold for the live-batch carousels; they are not rendered while this is undefined.
  thresholdForLiveCarousels?: Threshold
  // Forwarded to LiveBatchCarousels (non-null asserted there).
  inspectionId?: string | null
  aoiParentId?: string
  // Per-routine thresholds forwarded to prediction calculation and child carousels.
  backendThresholdByRoutine?: ThresholdByRoutineParentId
  modalLoaded: boolean
  // Every label of the tool; used to resolve a training result's calculated labels to a single label id.
  allToolLabels: ToolLabel[] | undefined
  // Labels that get rows in the table (sorted by value and severity before display).
  labelsUsed: ToolLabel[] | undefined
  // Used to resolve component ids to the names shown in the "Product" column.
  components: Component[] | undefined
  // Threshold used to compute each training result's prediction.
  threshold: Threshold | undefined
  insightsEnabled?: boolean
  // Forwarded to each label carousel as its `loading` state.
  loadingThreshold: boolean
  // Outer scroll container, forwarded to each label carousel.
  containerRef: React.RefObject<HTMLDivElement>
  // Metrics indexed as [labelId][componentId]; source of the image count and accuracy columns.
  trainingMetricsByLabelAndComponent: TrainingMetrics['byLabelId'] | undefined
  trainingResults: TrainingResultFlat[] | undefined
}

/**
 * "Samples" section of the training report.
 *
 * Renders either:
 * - the live-batch carousels, when `showCurrentBatchCarousels` is set (and a tool, a live threshold and a
 *   non-offline connection are available), or
 * - a table with one row per (label, component) pair found in `trainingResults`. Each row shows the image count
 *   and accuracy from `trainingMetricsByLabelAndComponent` and expands into a carousel of that pair's training
 *   results.
 *
 * When the `training_result_id` query param matches a training result listed in the table, its row is expanded
 * and scrolled into view once the data has loaded.
 */
export const TrainingReportSamples = ({
  trainingReportTab,
  tool,
  showCurrentBatchCarousels,
  insightsEnabled,
  connectionStatus,
  thresholdForLiveCarousels,
  inspectionId,
  aoiParentId,
  backendThresholdByRoutine,
  modalLoaded,
  allToolLabels,
  components,
  labelsUsed,
  threshold,
  loadingThreshold,
  containerRef,
  trainingMetricsByLabelAndComponent,
  trainingResults,
}: Props) => {
  // `null` means "not initialized yet": the effect below picks the initial expanded rows once data is ready.
  const [expandedRowsIds, setExpandedRowsIds] = useState<RowId[] | null>(null)
  // Row whose carousel has been expanded through its `onShowAll` callback.
  const [carouselExpandedId, setCarouselExpandedId] = useState<RowId>()
  const [params] = useQueryParams<'training_result_id'>()
  const { training_result_id } = params
  const defaultLabels = useDefaultToolLabels()

  const handleOnExpandedRowsChange = (expandedRows: readonly Key[]) => {
    setExpandedRowsIds(expandedRows as RowId[])
    // Collapsing a row also resets its carousel's expanded state.
    if (carouselExpandedId && !expandedRows.includes(carouselExpandedId)) setCarouselExpandedId(undefined)
  }

  const handleShowAll = useCallback(({ labelId, componentId }: { labelId: string; componentId: string }) => {
    const carouselKey: RowId = `${labelId}_${componentId}`
    setCarouselExpandedId(carouselKey)
  }, [])

  // Sort a copy: `Array.prototype.sort` works in place, and `labelsUsed` is a prop owned by the parent.
  const sortedLabels = useMemo(
    () => (labelsUsed ? [...labelsUsed].sort(sortByValueAndSeverity) : undefined),
    [labelsUsed],
  )

  // Groups training results as [labelId][componentId] using the single label id extracted from their
  // calculated labels. A null component id is stored under the key 'null'.
  const trainingResultsByLabelAndComponent = useMemo(() => {
    if (!allToolLabels) return
    return trainingResults?.reduce((trainingResultsByLabelAndComponent, trainingResult) => {
      const label = extractLabelIdFromCalculatedLabels(allToolLabels, trainingResult.calculated_labels)
      if (label) {
        const componentId = String(trainingResult.component_id)

        trainingResultsByLabelAndComponent[label] ??= {}
        trainingResultsByLabelAndComponent[label]![componentId] ??= []
        trainingResultsByLabelAndComponent[label]![componentId]!.push(trainingResult)
      }
      return trainingResultsByLabelAndComponent
    }, {} as { [labelId: string]: { [componentId: string]: TrainingResultFlat[] } })
  }, [allToolLabels, trainingResults])

  // Builds the table rows and, for each row, the carousel rendered when that row is expanded.
  // Both stay undefined until labels, grouped training results and metrics are all available.
  const { tableData, carousels } = useMemo(() => {
    if (!sortedLabels || !trainingResultsByLabelAndComponent || !trainingMetricsByLabelAndComponent)
      return { tableData: undefined, carousels: undefined }

    const { tableData, carousels } = sortedLabels.reduce(
      (tableDataAndCarousels, label) => {
        if (!(label.id in trainingResultsByLabelAndComponent)) return tableDataAndCarousels
        const componentsInTrainingResults = Object.keys(trainingResultsByLabelAndComponent[label.id] || {})
        const { tableData, carousels } = tableDataAndCarousels

        componentsInTrainingResults.forEach(componentId => {
          // Pairs without metrics get neither a row nor a carousel.
          const metricsForLabelAndComponent = trainingMetricsByLabelAndComponent[label.id]?.[componentId]
          if (!metricsForLabelAndComponent) return

          const rowKey: RowId = `${label.id}_${componentId}`
          const imagesUsed = metricsForLabelAndComponent.total
          const correctPredictionsCount = metricsForLabelAndComponent.succesful

          const isCarouselExpanded = carouselExpandedId === rowKey
          const isRowExpanded = expandedRowsIds?.includes(rowKey)

          // Training results of this component whose calculated labels include this label, enriched with their
          // prediction. Left undefined until threshold, tool and default labels are all available.
          const filteredTrainingResults =
            threshold && tool && defaultLabels
              ? trainingResults?.reduce(
                  (acc, trainingResult) => {
                    // Component ids were stringified when grouping, so 'null' stands for a null component id.
                    const fromCurrentComponent =
                      componentId === 'null'
                        ? trainingResult.component_id === null
                        : trainingResult.component_id === componentId
                    if (trainingResult.calculated_labels.includes(label.id) && fromCurrentComponent) {
                      const calculatedLabelId = extractLabelIdFromCalculatedLabels(
                        allToolLabels || [],
                        trainingResult.calculated_labels,
                      )
                      acc.push({
                        ...trainingResult,
                        calculatedLabelId,
                        ...calculateTrainingResultPrediction({
                          trainingResult,
                          specName: tool?.specification_name,
                          threshold: threshold || {},
                          defaultLabels,
                          allToolLabels,
                          thresholdByRoutine: backendThresholdByRoutine,
                        }),
                      })
                    }
                    return acc
                  },
                  [] as (TrainingResultFlat & {
                    calculatedLabelId: string | undefined
                    isAboveThreshold: boolean | undefined
                    predictedLabel: ToolLabel | undefined
                  })[],
                )
              : undefined

          const filteredTrainingResultIds = filteredTrainingResults?.map(trainingResult => trainingResult.id)

          carousels[rowKey] = (
            <MemoLabelContainer
              key={rowKey}
              rowKey={rowKey}
              carouselKey={`${rowKey}-${trainingReportTab}`}
              outerContainerRef={containerRef}
              onShowAll={handleShowAll}
              expanded={isCarouselExpanded}
              label={label}
              loading={loadingThreshold}
              threshold={threshold}
              tool={tool}
              toolLabels={allToolLabels}
              componentId={componentId}
              trainingResults={filteredTrainingResults}
              modalLoaded={modalLoaded}
              thresholdByRoutine={backendThresholdByRoutine}
              insightsEnabled={insightsEnabled}
              isRowExpanded={!!isRowExpanded}
            />
          )

          const component = components?.find(component => component.id === componentId)
          tableData.push({
            key: rowKey,
            label: {
              image: getToolLabelImagesToShow(label),
              value: label.value,
              severity: label.severity,
              id: label.id,
            },
            product: component?.name || '--',
            images: imagesUsed,
            accuracy: calculatePercentage(correctPredictionsCount, imagesUsed),
            productId: component?.id || '',
            trainingResultsIds: filteredTrainingResultIds,
          })
        })

        return { tableData, carousels }
      },
      { tableData: [], carousels: {} } as {
        tableData: RowData[]
        carousels: { [key: RowId]: ReactNode }
      },
    )
    // Sort by product by default; `tableData` is a freshly built local array, so sorting in place is safe.
    tableData.sort((a: RowData, b: RowData) => a.product.localeCompare(b.product))

    return { tableData, carousels }
  }, [
    sortedLabels,
    trainingResultsByLabelAndComponent,
    trainingMetricsByLabelAndComponent,
    carouselExpandedId,
    expandedRowsIds,
    trainingResults,
    trainingReportTab,
    containerRef,
    handleShowAll,
    loadingThreshold,
    threshold,
    tool,
    allToolLabels,
    modalLoaded,
    backendThresholdByRoutine,
    insightsEnabled,
    components,
    defaultLabels,
  ])

  const isLoading = !tableData || !carousels

  // One-time initialization of the expanded rows, once the rows carry their training result ids: expand (and
  // scroll to) the row containing `training_result_id` if any, otherwise start with every row collapsed.
  useEffect(() => {
    if (expandedRowsIds !== null || isLoading || tableData.every(row => row.trainingResultsIds === undefined)) return

    const run = async () => {
      const initialExpandedRowIds: RowId[] = []

      let foundTableRow: RowData | undefined = undefined

      if (training_result_id) {
        foundTableRow = tableData.find(row => row.trainingResultsIds?.includes(training_result_id))
        if (foundTableRow) {
          initialExpandedRowIds.push(foundTableRow.key)
        }
      }

      setExpandedRowsIds(initialExpandedRowIds)
      if (foundTableRow) {
        // Wait just a little bit for the expand animation to run
        await sleep(300)
        // Rows are rendered with `id={rowData.key}` (see `onRow` below).
        const tableRow = document.getElementById(foundTableRow.key)
        tableRow?.scrollIntoView(true)
      }
    }

    // Fire-and-forget on purpose. No cancellation on cleanup: setting `expandedRowsIds` re-runs this effect,
    // which would otherwise cancel the pending scroll.
    void run()
  }, [expandedRowsIds, isLoading, tableData, training_result_id])

  return (
    <section
      className={`${Styles.carouselsWrapper} ${!tableData || tableData?.length === 0 ? Styles.containerHeight : ''}`}
    >
      {tool &&
        showCurrentBatchCarousels &&
        connectionStatus !== 'offline' &&
        thresholdForLiveCarousels !== undefined && (
          <LiveBatchCarousels
            // NOTE(review): non-null assertion kept from the original; confirm inspectionId is always set here.
            inspectionId={inspectionId!}
            threshold={thresholdForLiveCarousels}
            modalLoaded={modalLoaded}
            tool={tool}
            aoiParentId={aoiParentId}
            thresholdByRoutine={backendThresholdByRoutine}
          />
        )}

      {!showCurrentBatchCarousels && connectionStatus !== 'offline' && (
        <Table
          dataSource={tableData}
          columns={columns}
          pagination={false}
          onRow={rowData => {
            return {
              id: rowData.key,
              'data-test': 'training-report-samples-row',
              'data-testid': `training-report-row-${rowData.label.severity}-${rowData.label.value}`,
            }
          }}
          expandable={{
            expandedRowRender: record => carousels?.[record.key],
            expandedRowKeys: expandedRowsIds || [],
            onExpandedRowsChange: handleOnExpandedRowsChange,
            expandRowByClick: true,
            expandIconColumnIndex: -1,
          }}
          className={Styles.tableContainer}
          components={{
            body: { row: 'div' },
          }}
          rowClassName={record => {
            let rowClassName = `${Styles.tableRow} `
            if (expandedRowsIds?.includes(record.key)) rowClassName += Styles.rowIsActive
            return rowClassName
          }}
          loading={{
            spinning: isLoading,
            indicator: <PrismLoader />,
          }}
          locale={{
            emptyText: !isLoading ? (
              <GenericBlankStateMessage
                header={<PrismSearchIcon />}
                description="No images in this set"
                className={Styles.emptyStateWrapper}
              />
            ) : (
              <div></div>
            ),
          }}
        />
      )}
    </section>
  )
}

/** Returns an antd `onCell` callback that tags every cell of a column with the given `data-test` attribute. */
const dataTestCell = (dataTest: string) => () => ({ 'data-test': dataTest } as React.HtmlHTMLAttributes<RowData>)

/**
 * Column definitions of the samples table. Every column is sortable on the client; the label column shows the
 * label's first image plus its severity/value badge, and the accuracy column is formatted with one decimal.
 */
const columns: ColumnsType<RowData> = [
  {
    title: 'Label',
    dataIndex: 'label',
    key: 'label',
    width: '30%',
    sorter: (first: RowData, second: RowData) => first.label.value.localeCompare(second.label.value),
    render: ({ image, value, severity }: RowData['label']) => (
      <ImageAndPrismResult image={image?.[0]} value={value} severity={severity} />
    ),
  },
  {
    title: 'Product',
    dataIndex: 'product',
    key: 'product',
    sorter: (first: RowData, second: RowData) => first.product.localeCompare(second.product),
  },
  {
    title: 'Images',
    dataIndex: 'images',
    key: 'images',
    sorter: (first: RowData, second: RowData) => first.images - second.images,
    onCell: dataTestCell('training-report-images-cell'),
  },
  {
    title: 'Accuracy',
    dataIndex: 'accuracy',
    key: 'accuracy',
    sorter: (first: RowData, second: RowData) => first.accuracy - second.accuracy,
    render: (accuracy: RowData['accuracy']) => `${accuracy.toFixed(1)}%`,
    onCell: dataTestCell('training-report-accuracy-cell'),
  },
]

/**
 * Renders the content of the "Label" column: the label's image (or a generic cube icon when there is no image),
 * followed by a PrismResultButton showing the severity icon and the label name.
 */
const ImageAndPrismResult = ({
  image,
  value,
  severity,
}: {
  image?: string
  value: string
  severity: LabelButtonSeverity
}) => (
  <div className={Styles.labelColumn}>
    <figure className={Styles.labelImage}>
      {image ? (
        <ImgFallback src={image} loaderType="skeleton" className={Styles.toolColumnImage} />
      ) : (
        <PrismElementaryCube />
      )}
    </figure>

    <PrismResultButton
      data-testid="training-report-predicted-label"
      type="noFill"
      size="small"
      value={value}
      severity={severity}
    />
  </div>
)
