import React, { type ReactNode, useMemo } from 'react';

import { Tooltip } from '@mui/material';
import {
  DataGridPremium,
  GRID_TREE_DATA_GROUPING_FIELD,
  type GridColDef,
} from '@mui/x-data-grid-premium';
import { type GridApiPremium } from '@mui/x-data-grid-premium/models/gridApiPremium';
import { format } from 'date-fns';
import _, { isNumber } from 'lodash';
import numbro from 'numbro';

import { ModelCompareTooltip } from '../common/dataTableFormatter';
import cn from '../utils/cn';

/**
 * Props for the DataGridPremiumTable component.
 */
type DataGridPremiumTableProps = {
  // Ref handed to DataGridPremium's `apiRef` prop.
  apiRef: React.MutableRefObject<GridApiPremium>;
  // Column definitions. `name` is the grid field and the key into `data` / `base.data`;
  // `label` is the header text; pinned columns are pinned left, get a wider min width and
  // a left-aligned header; `data_type: 'date'` values are rendered as dates.
  columns: {
    name: string;
    label: string;
    is_pinned: boolean;
    data_type: 'number' | 'string' | 'date';
    // Style entries. When no entry has row filters, the last entry's `numfmt` is the
    // column's numbro format; otherwise a row uses the `numfmt` of the entry whose
    // `row_filters` item matches it.
    styles: Record<string, Record<string, boolean | number | string | any>>[];
  }[];
  // When set, the grid renders tree data: `groupColumns` are not shown as regular
  // columns, their values form each row's tree path, and `headerName` titles the
  // grouping column.
  treeData?: {
    headerName: string;
    groupColumns: string[];
  };
  // Forwarded to the grid's `loading` prop.
  isLoading: boolean;
  // Text shown in the overlay when there are no rows.
  placeholder?: string;
  // Column-oriented data: column name -> values, where index i of every array is row i.
  // `data.id[i]`, if present, becomes row i's `type` (a `'Total'` row is pinned to the bottom).
  data: Record<string, (string | number)[]> | undefined;
  // Optional baseline with the same shape as `data`, looked up by column name and row index.
  // `mode` selects CompareCell ('compare') or VectorCell ('vector') rendering.
  base?: {
    data: Record<string, (string | number)[]> | undefined;
    mode: 'compare' | 'vector';
  };
};

/**
 * Renders a cell in "compare" mode, showing how `value` differs from `baseValue`.
 *
 * - No (or zero) base value, or a pinned column: the value itself is shown.
 * - Equal values: `-`.
 * - `value` above `baseValue`: the difference, wrapped in parentheses, in red.
 * - `baseValue` above `value`: the difference in green.
 *
 * Every numeric result (including the parenthesized difference) is formatted with
 * numbro using `numberFormat`. The whole cell is wrapped in a ModelCompareTooltip.
 */
const CompareCell = ({
  value,
  baseValue,
  numberFormat,
  isPinned,
}: {
  value: number;
  baseValue: number;
  isPinned: boolean;
  numberFormat: Record<string, any> | undefined;
}) => {
  const formatNumber = (n: number) => numbro(n).format(numberFormat);

  let result: string | number | undefined;
  let color: string | undefined;

  if (!baseValue || isPinned) {
    // `value` may not actually be a number at runtime, so only format real numbers.
    result = isNumber(value) ? formatNumber(value) : value;
  } else if (Number(value) === Number(baseValue)) {
    result = '-';
  } else if (Number(value) > Number(baseValue)) {
    // Format the difference before wrapping it; a raw template string would bypass
    // `numberFormat` and leak floating-point noise (e.g. "(0.30000000000000004)").
    result = `(${formatNumber(Number(value) - Number(baseValue))})`;
    color = '#ff0000';
  } else if (Number(baseValue) > Number(value)) {
    result = formatNumber(Number(baseValue) - Number(value));
    color = 'rgb(46 161 46)';
  }

  return (
    <ModelCompareTooltip numberA={baseValue} numberB={value}>
      <span
        style={{
          color,
        }}
      >
        {result}
      </span>
    </ModelCompareTooltip>
  );
};

/**
 * Renders a cell in "vector" mode: the formatted value colored by its direction
 * relative to the base (black when equal, green when above, red otherwise), with a
 * tooltip that shows the formatted base value.
 */
const VectorCell = ({
  baseValue,
  value,
  numberFormat,
}: {
  value: number;
  baseValue: number;
  numberFormat: Record<string, any> | undefined;
}) => {
  // Red is the fallback, which also covers incomparable values such as NaN.
  let color = '#ff0000';
  if (value === baseValue) {
    color = '#000000';
  } else if (value > baseValue) {
    color = 'rgb(46 161 46)';
  }

  const tooltipTitle = (
    <span className="text-[0.875rem] w-full">
      Base value: {numbro(baseValue).format(numberFormat ?? {})}
    </span>
  );

  return (
    <Tooltip disableFocusListener title={tooltipTitle}>
      <div style={{ color }}>
        <span>{numbro(value).format(numberFormat ?? {})}</span>
      </div>
    </Tooltip>
  );
};

/**
 * Wraps a cell's content in a full-width flex container: left-aligned for pinned
 * columns, right-aligned otherwise.
 *
 * - `date` columns: truthy `children` is converted with `Number()` and interpreted as a
 *   millisecond timestamp, then rendered as `y/MM/dd` in local time. Content that does
 *   not produce a valid Date is shown unchanged.
 * - When `isPinned` is exactly `false`, a zero result (`0` or `'0'`) is rendered as `-`.
 */
const FormatRow = ({
  isPinned,
  children,
  data_type,
}: {
  isPinned: boolean;
  children: ReactNode;
  data_type: 'number' | 'string' | 'date';
}) => {
  let result = children;

  if (data_type === 'date' && children) {
    const date = new Date(Number(children));

    // Format the Date directly. Round-tripping through toLocaleDateString() produced a
    // locale-dependent string that date-fns cannot reliably parse, and an invalid date
    // made `format` throw during render. `y` is the calendar year; the previous `Y` is
    // the local week-numbering year, which is wrong for some dates around New Year.
    if (!Number.isNaN(date.getTime())) {
      result = format(date, 'y/MM/dd');
    }
  }

  if (isPinned === false && (result === 0 || result === '0')) {
    result = '-';
  }

  return (
    <div
      // Emit exactly one justify-* class so alignment does not depend on CSS order.
      className={cn('items-center flex w-full', {
        'justify-start': isPinned,
        'justify-end': !isPinned,
      })}
    >
      {result}
    </div>
  );
};

/**
 * Renders column-oriented `data` in an MUI DataGridPremium.
 *
 * - Row i is assembled from index i of every column array in `data`. Its `id` is i,
 *   and its `type` comes from `data.id[i]`, defaulting to `'row'`.
 * - With `treeData`, the `groupColumns` are removed from each row. Their values
 *   (excluding `'all'`) become the row's tree path, falling back to `['Total']`.
 * - A row whose `type` is `'Total'` is also pinned to the bottom of the grid.
 * - Numeric cells are formatted with numbro using the column's style `numfmt`, which can
 *   be chosen per row via `row_filters`.
 * - When a cell has a base value, it is rendered with CompareCell or VectorCell
 *   according to `base.mode`.
 */
const DataGridPremiumTable = ({
  columns,
  data,
  apiRef,
  isLoading,
  base,
  treeData,
  placeholder,
}: DataGridPremiumTableProps) => {
  const cols = useMemo(
    () =>
      columns
        .filter(({ name }) => !treeData?.groupColumns.includes(name))
        .map(({ label, name, is_pinned, styles, data_type }) => {
          // Column-wide defaults: the numfmt of the last style entry, plus the row
          // filters collected from every style entry.
          const mergedStyles = styles
            ? {
                numfmt: styles[styles.length - 1]?.numfmt ?? {},
                row_filters: styles.map((s) => s.row_filters).flat(),
              }
            : {};

          // Every row filter, tagged with the numfmt of the style entry it came from.
          const filters =
            styles
              ?.map((i) =>
                i.row_filters.map((r: any) => ({
                  ...r,
                  numfmt: i.numfmt,
                }))
              )
              .flat() ?? [];

          const rowFilters = mergedStyles.row_filters;

          const baseMode = base?.mode;

          return {
            field: name,
            headerName: label,
            flex: 100,
            minWidth: is_pinned ? 200 : 140,
            align: 'right',
            headerAlign: is_pinned ? 'left' : 'right',
            hideable: false,
            sortable: false,
            pinnable: is_pinned,
            filterable: false,
            style: {
              backgroundColor: 'black',
              fontWeight: 'bold',
            },
            disableColumnMenu: true,
            renderCell: (params) => {
              const { id, field, value } = params;

              // `base.data` is column-oriented like `data`, indexed by the row id
              // (the row's position in `data`).
              const baseValue = base?.data?.[field]?.[id as number];

              // Resolved per cell. Previously this was a `let` shared by every cell
              // of the column and mutated on each render.
              let numberFormat: Record<string, any> | undefined;

              if (!rowFilters || rowFilters.length === 0) {
                numberFormat = mergedStyles.numfmt;
              } else {
                // Prefer a filter matching the row's top-level tree group, then one
                // whose field value matches the row's own value.
                const groupingKey = params.row.path?.[0];

                const rowFilter =
                  filters.find((r) => r.value === groupingKey) ??
                  filters.find((r) => params.row[r.field] === r.value);

                numberFormat = rowFilter?.numfmt;
              }

              if (!isNumber(value) && !isNumber(baseValue)) {
                return (
                  <FormatRow isPinned={is_pinned} data_type={data_type}>
                    {value}
                  </FormatRow>
                );
              }

              if (!baseValue) {
                return (
                  <FormatRow isPinned={is_pinned} data_type={data_type}>
                    {/* Date cells get the raw timestamp: FormatRow parses it with
                        Number(), which numbro output (separators, abbreviations)
                        would break. */}
                    {data_type === 'date'
                      ? value
                      : numbro(value).format(numberFormat)}
                  </FormatRow>
                );
              }

              switch (baseMode) {
                case 'compare':
                  return (
                    <FormatRow isPinned={is_pinned} data_type={data_type}>
                      <CompareCell
                        value={value}
                        isPinned={is_pinned}
                        baseValue={baseValue as number}
                        numberFormat={numberFormat}
                      />
                    </FormatRow>
                  );
                case 'vector':
                  return (
                    <FormatRow isPinned={is_pinned} data_type={data_type}>
                      <VectorCell
                        value={value}
                        baseValue={baseValue as number}
                        numberFormat={numberFormat}
                      />
                    </FormatRow>
                  );
                default:
                  // This JSX was previously built but never returned, leaving the
                  // cell empty.
                  return (
                    <FormatRow isPinned={is_pinned} data_type={data_type}>
                      {value}
                    </FormatRow>
                  );
              }
            },
          };
        }) as GridColDef[],
    [base, columns, treeData]
  );

  // Pivot the column-oriented `data` into row objects keyed by column name.
  const rows = useMemo(
    () =>
      data
        ? columns.reduce(
            (acc, cur) => {
              const { name } = cur;

              data[name]?.forEach(
                (value, index) =>
                  (acc[index] = {
                    id: index,
                    ...acc[index],
                    type: data.id?.[index] ?? 'row',
                    [name]: value,
                  })
              );

              return acc;
            },
            [] as Record<string, string | number>[]
          )
        : [],
    [data, columns]
  );

  // Replace the group columns with a tree `path` built from their values, where
  // 'all' means "not grouped at this level".
  const groupRows = useMemo(
    () =>
      rows?.map((a) => {
        const path = treeData?.groupColumns
          .map((c) => a[c])
          .filter((i) => i !== 'all');

        return {
          ...(treeData?.groupColumns ? _.omit(a, treeData?.groupColumns) : a),
          path: path?.length ? path : ['Total'],
        };
      }) as unknown as Record<string, string | number | string[]>[],
    [rows, treeData]
  );

  const pinnedRow = groupRows.find((r) => r.type === 'Total');

  return (
    <div className="h-screen pb-6">
      <DataGridPremium
        treeData={!!treeData}
        groupingColDef={
          treeData
            ? {
                headerName: treeData.headerName,
                width: 340,
                hideDescendantCount: true,
              }
            : {}
        }
        getTreeDataPath={(row) =>
          treeData ? (row.path?.length ? row.path : ['']) : [row.id]
        }
        pinnedRows={
          pinnedRow && {
            bottom: [
              {
                ...pinnedRow,
                path: ['Total'],
              },
            ],
          }
        }
        apiRef={apiRef}
        rows={groupRows?.length ? groupRows : []}
        columns={cols}
        components={{
          NoRowsOverlay: () => (
            <div className="h-full flex justify-center text-[#666666] font-semibold items-center">
              {placeholder}
            </div>
          ),
        }}
        pagination={false}
        loading={isLoading}
        autoPageSize
        rowGroupingColumnMode="multiple"
        getRowClassName={(params) => {
          const maxLevel = treeData?.groupColumns.length;
          const groupLevel = params.row.path?.length;

          const rowType = params.row.type;

          if (maxLevel === 3) {
            switch (groupLevel) {
              case 1:
                return 'first-level-group-row';
              case 2:
                return 'second-level-group-row';
              default:
                return '';
            }
          } else if (maxLevel === 2) {
            switch (groupLevel) {
              case 1:
                return 'first-level-group-row';
              default:
                return '';
            }
          }

          if (rowType === 'Total') {
            return 'total-row';
          }

          return '';
        }}
        sx={{
          '& .MuiDataGrid-columnHeaders': {
            backgroundColor: '#01285F',
            color: '#FFFFFF !important',
            fontWeight: 'bold',
          },

          '& .MuiDataGrid-pinnedColumnHeaders': {
            backgroundColor: '#01285F',
          },

          '& .MuiDataGrid-row.first-level-group-row': {
            backgroundColor: '#B1D0FC',
            fontWeight: 'bold',
          },

          '& .MuiDataGrid-row.second-level-group-row': {
            backgroundColor: '#CEE2FF',
          },

          '& .MuiDataGrid-row.total-row': {
            backgroundColor: '#B1D0FC',
            fontWeight: 'bold',
          },
        }}
        hideFooter
        disableAggregation
        density="compact"
        initialState={{
          aggregation: { model: { gross: 'sum' } },
        }}
        pinnedColumns={{
          left: cols
            .filter((col) => col.pinnable)
            .map((col) => col.field)
            .concat(GRID_TREE_DATA_GROUPING_FIELD),
        }}
        experimentalFeatures={{ rowPinning: true }}
        defaultGroupingExpansionDepth={2}
      />
    </div>
  );
};

export default DataGridPremiumTable;
