import { CohortIndex } from '@groundwater/shared-ui';
import {
  Paper,
  Table as MuiTable,
  TableBody,
  TableCell,
  TableContainer,
  TableHead,
  TableRow,
} from '@mui/material';
import { isNil } from 'lodash-es';
import { FixedTableCell } from './RetentionTableCells/CohortPeriodTableCell';
import { DefaultCell } from './RetentionTableCells/DefaultCell';
import { format } from 'd3-format';
import { CohortedRetentionVisualizationProps } from '../features/CohortedRetention/types';
import { notEmpty } from '@groundwater/shared-util';
import { scaleSequential } from 'd3-scale';
import { interpolateBlues } from 'd3-scale-chromatic';

/**
 * Props accepted by {@link RetentionHeatmap}: the shared cohorted-retention
 * visualization props plus the heatmap-specific options below.
 */
export type RetentionHeatmapProps = CohortedRetentionVisualizationProps & {
  /** Header text rendered above the second fixed column (the `period_0` values). */
  period_0_label: string;
  /**
   * Controls how the color-scale domain is seeded. When true the domain starts
   * empty and is fitted to the smallest/largest data values kept by `notEmpty`;
   * when false it starts at [0, 1] and only widens if data falls outside it.
   */
  y_zoom: boolean;
};

/**
 * Number of equal steps the legend is divided into. The legend renders
 * `numLegendLabels + 1` swatches, one at each step boundary from min to max.
 */
const numLegendLabels = 10;

/**
 * Extracts up to three color channel values (red, green, blue) from a CSS
 * color string.
 *
 * Supports hex notation (`#rgb`, `#rgba`, `#rrggbb`, `#rrggbbaa`) and any
 * functional notation whose first three integer runs are the channels, such as
 * `rgb(8, 48, 107)` or `rgba(8, 48, 107, 0.5)`. Alpha is always ignored.
 *
 * @returns The parsed channels, or an empty array when none could be parsed.
 */
const parseColorChannels = (color: string): number[] => {
  const text = color.trim();

  const hex = /^#([0-9a-f]{3}|[0-9a-f]{4}|[0-9a-f]{6}|[0-9a-f]{8})$/i.exec(
    text
  );
  if (hex) {
    const digits = hex[1] ?? '';
    // Short forms use one hex digit per channel, which CSS expands by doubling.
    const width = digits.length <= 4 ? 1 : 2;
    return [0, 1, 2].map((channel) => {
      const part = digits.slice(channel * width, (channel + 1) * width);
      return parseInt(width === 1 ? part + part : part, 16);
    });
  }

  // Only the first three runs are channels; anything after that is alpha.
  return (text.match(/[0-9]+/g) ?? [])
    .slice(0, 3)
    .map((run) => parseInt(run, 10));
};

/**
 * Picks a readable font color for text drawn on top of `bgColor`.
 *
 * The background's lightness is approximated by the sum of its red, green and
 * blue channels: sums above 500 are treated as light backgrounds and get black
 * text, everything else gets white text.
 *
 * @param bgColor CSS color string of the background.
 * @returns `'black'` or `'white'`. Falls back to `'black'` when no channels
 *   can be parsed from `bgColor` (for example a named color).
 */
const fontColor = (bgColor: string): 'black' | 'white' => {
  const channels = parseColorChannels(bgColor);
  if (channels.length === 0) {
    return 'black';
  }

  const backgroundColorLightness = channels.reduce((sum, val) => val + sum, 0);
  return backgroundColorLightness > 500 ? 'black' : 'white';
};

/**
 * Renders cohorted retention data as a color-coded table next to a color legend.
 *
 * The color scale's domain is derived from the period data: it is seeded with
 * [0, 1] (or left empty when `y_zoom` is set) and then widened to cover every
 * value kept by `notEmpty`. The same scale is handed to both the table and the
 * legend so their colors agree.
 */
export const RetentionHeatmap: React.FC<RetentionHeatmapProps> = (props) => {
  const observed = props.periods
    .map((period) => period.data.filter(notEmpty))
    .filter(notEmpty)
    .reduce(
      ({ min, max }, data) => ({
        min: Math.min(...data, min),
        max: Math.max(...data, max),
      }),
      props.y_zoom ? { min: Infinity, max: -Infinity } : { min: 0, max: 1 }
    );

  // With y_zoom enabled and no usable data the bounds stay at +/-Infinity.
  // Every legend step then evaluates to NaN, for which the sequential scale
  // returns its "unknown" value instead of a color string, and the render
  // crashes. Fall back to the default [0, 1] domain in that case.
  const hasFiniteDomain =
    Number.isFinite(observed.min) && Number.isFinite(observed.max);
  const { min, max } = hasFiniteDomain ? observed : { min: 0, max: 1 };

  const interpolator = scaleSequential()
    .interpolator(interpolateBlues)
    .domain([min, max]);

  return (
    <div style={{ display: 'flex' }}>
      <div style={{ flexGrow: 1 }}>
        <Table
          {...{
            ...props,
            interpolator,
          }}
        />
      </div>
      <div style={{ width: 150 }}>
        <Legend {...{ interpolator, min, max }} />
      </div>
    </div>
  );
};

/**
 * The heatmap table: one row per cohort and one column per period, preceded by
 * two fixed columns holding the cohort label and its `period_0` value.
 *
 * Each data cell shows the value as a whole-number percentage on a background
 * color taken from `interpolator`; cells without a value render empty.
 */
const Table: React.FC<
  RetentionHeatmapProps & {
    interpolator: (num: number) => string;
  }
> = ({ interpolator, cohorts, period_0, period_0_label, periods, xaxis }) => {
  const cellWidth = xaxis === CohortIndex.Cohort ? 50 : 110;

  // Build the percentage formatter once per render rather than once per cell.
  const formatPercent = format('.0%');

  return (
    <TableContainer component={Paper}>
      <MuiTable
        size="small"
        sx={{ width: 'auto' }}
        aria-label="cohorted-retention-table"
      >
        <TableHead>
          <TableRow>
            <FixedTableCell minWidth={120} left={0}></FixedTableCell>
            {/* NOTE(review): this header cell uses minWidth 120 while the matching body cell uses 170 - confirm which is intended. */}
            <FixedTableCell align="right" minWidth={120} left={120}>
              {period_0_label}
            </FixedTableCell>

            {/* Each period is one column of the table; its id is the column label. */}
            {periods.map((period, periodIndex) => (
              <TableCell key={`${periodIndex}-${period.id}`} align="center">
                {period.id}
              </TableCell>
            ))}
          </TableRow>
        </TableHead>
        <TableBody>
          {cohorts.map((cohort, cohort_index) => {
            return (
              <TableRow
                key={cohort}
                sx={{ '&:last-child td, &:last-child th': { border: 0 } }}
              >
                <FixedTableCell minWidth={120} left={0}>
                  {cohort}
                </FixedTableCell>

                <FixedTableCell align="right" minWidth={170} left={120}>
                  {period_0[cohort_index]}
                </FixedTableCell>

                {/*
                 * Iterate the periods directly: looking each column's period up
                 * again by id is redundant and, if two periods ever share an id,
                 * would show the first one's data in both columns.
                 */}
                {periods.map((period, periodIndex) => {
                  const cellKey = `${periodIndex}-${period.id}`;

                  const value: null | number =
                    period.data[cohort_index] ?? null;

                  if (isNil(value)) {
                    return <DefaultCell key={cellKey} minWidth={cellWidth} />;
                  }

                  /**
                   * Background of the cell represents where the value falls
                   * within the color scale's domain
                   */
                  const backgroundColor = interpolator(value);

                  return (
                    <DefaultCell
                      key={cellKey}
                      minWidth={cellWidth}
                      fontColor={fontColor(backgroundColor)}
                      backgroundColor={backgroundColor}
                      align="center"
                    >
                      {formatPercent(value)}
                    </DefaultCell>
                  );
                })}
              </TableRow>
            );
          })}
        </TableBody>
      </MuiTable>
    </TableContainer>
  );
};

/**
 * Vertical color key for the heatmap.
 *
 * Renders `numLegendLabels + 1` stacked swatches at evenly spaced values from
 * `min` to `max`. Each swatch is filled with the color `interpolator` returns
 * for its value and labelled with that value as a whole-number percentage.
 */
const Legend: React.FC<{
  interpolator: (num: number) => string;
  min: number;
  max: number;
}> = ({ interpolator, min, max }) => {
  const formatPercent = format('.0%');

  // Fractions 0, 1/n, 2/n, ..., 1 along the [min, max] range.
  const fractions = Array.from(
    { length: numLegendLabels + 1 },
    (_, step) => step / numLegendLabels
  );

  const swatches = fractions.map((fraction) => {
    const value = min + (max - min) * fraction;
    const swatchColor = interpolator(value);

    return (
      <div
        key={fraction}
        style={{
          backgroundColor: swatchColor,
          flexGrow: 1,
          color: fontColor(swatchColor),
          textAlign: 'center',
        }}
      >
        {formatPercent(value)}
      </div>
    );
  });

  return (
    <>
      Color Key:
      <div
        style={{
          display: 'flex',
          flexDirection: 'column',
          width: 50,
        }}
      >
        {swatches}
      </div>
    </>
  );
};
