import BasePlotBuilder, { ConstructorParams, PlotMargin } from '@components/plots/builders/BasePlotBuilder';
import { ComparativeAnalysisPlotData, DEGSample } from '@models/ExperimentData';
import {
    axisBottom,
    axisLeft,
    BaseType,
    max,
    min,
    scaleBand,
    ScaleBand,
    scaleLinear,
    ScaleLinear,
    Selection,
} from 'd3';
import cn from 'classnames';
import {
    AXIS_LABEL_CLASSNAMES,
    AXIS_LABEL_PUBLICATION_CLASSNAMES,
    AXIS_TITLE_CLASSNAMES,
    AXIS_TITLE_PUBLICATION_CLASSNAMES,
    ComparativeHeatmapColorScale,
} from '@models/PlotConfigs';
import { formatTableHeader, roundToDecimal } from '@util/StringUtil';
import { rotateXAxisLabels as drawRotatedXAxisLabels, wrapTextNode } from '@components/plots/PlotUtil';
import { getComparativeHeatmapColorScale, getPlotPalette } from '@components/ColorPaletteUtil';
import Plot from '@models/Plot';
import HeatmapDisplayOption from '@models/plotDisplayOption/HeatmapDisplayOption';
import { AnalysisGroup } from '@models/AnalysisParameters';
import { isDefined } from '@util/TypeGuards';
import { ComparativeAnalysis } from '@models/analysis/ComparativeAnalysis';

/** Plot model rendered by this builder: a comparative analysis with heatmap display options. */
type PlotType = Plot<ComparativeAnalysis, HeatmapDisplayOption>;

/**
 * Builder constructor arguments: the base plot-builder params plus the two analysis groups
 * (control and experimental) whose summary values make up the heatmap rows.
 */
type ConstructorArgs = ConstructorParams<ComparativeAnalysisPlotData, PlotType> & {
    controlGroup: AnalysisGroup;
    experimentalGroup: AnalysisGroup;
};

/** One heatmap cell: a single target's summary value for a single group. */
type Sample = {
    value: number;
    group_name: string;
    group_id: number;
    target_name: string;
    log2_fold_change: number;
};

/**
 * Draws a two-row heatmap (control group vs. experimental group) of per-target summary z-scores
 * for the rows of a comparative analysis that pass `adj_p_value_threshold`. A group color legend
 * is drawn at the left edge of the rows and a gradient scale legend above the plot.
 */
export default class ComparativeHeatmapBuilder extends BasePlotBuilder<ComparativeAnalysisPlotData, PlotType> {
    /** Rotation (degrees) applied to x-axis target labels; a falsy value disables rotation. */
    xAxisLabelRotation = 45;
    /** Width in px of each group color swatch; the heatmap cells start 2x this width to the right. */
    legendWidth = 16;
    scales: {
        x: ScaleBand<string>;
        y: ScaleBand<string>;
        color: ScaleLinear<string, number>;
    };
    /** Caller-owned group objects. They are never mutated; see getSortedGroups. */
    controlGroup: AnalysisGroup;
    experimentalGroup: AnalysisGroup;
    /** Display names of the plotted targets, in the same order as `sortedData`. */
    targetNames: string[];
    /** Rows with `Adj_P_Value` at or below the threshold, sorted by descending `Log2_Fold_Change`. */
    sortedData: DEGSample[];

    constructor(options: ConstructorArgs) {
        super(options);
        this.controlGroup = options.controlGroup;
        this.experimentalGroup = options.experimentalGroup;
        const pValueThreshold = this.plot.display.adj_p_value_threshold ?? 0.001;
        // filter() already returns a new array, so sorting it leaves this.data.items untouched.
        this.sortedData = this.data.items
            .filter((d) => d.Adj_P_Value <= pValueThreshold)
            .sort((s1, s2) => s2.Log2_Fold_Change - s1.Log2_Fold_Change);
        this.targetNames = this.sortedData.map(this.getDisplayname);
        this.scales = this.makeScales();
    }

    /**
     * Escapes HTML-significant characters so data-derived strings (target names, group names that
     * may come from `custom_legend_json`) can be interpolated into tooltip markup safely.
     */
    static escapeHtml(value: string): string {
        return String(value)
            .replace(/&/g, '&amp;')
            .replace(/</g, '&lt;')
            .replace(/>/g, '&gt;')
            .replace(/"/g, '&quot;')
            .replace(/'/g, '&#39;');
    }

    /**
     * Experimental-group summary value for one row, chosen by `summarize_values_by`.
     * Prefers the `*_ZscoreLog2CPM_Experimental` column and falls back to `*_Zscore_Experimental`
     * when that is null/undefined.
     * @throws Error for any `summarize_values_by` other than 'mean' or 'median'.
     */
    getExperimentalValue = (d: DEGSample): number => {
        const summarizeValuesBy = this.plot.display.summarize_values_by;

        switch (summarizeValuesBy) {
            case 'mean':
                return d.Average_ZscoreLog2CPM_Experimental ?? d.Average_Zscore_Experimental;
            case 'median':
                return d.Median_ZscoreLog2CPM_Experimental ?? d.Median_Zscore_Experimental;
            default:
                throw new Error(`Unable to group by "${summarizeValuesBy}"`);
        }
    };

    /**
     * Control-group summary value for one row; same column selection and fallback rules as
     * getExperimentalValue.
     * @throws Error for any `summarize_values_by` other than 'mean' or 'median'.
     */
    getControlValue = (d: DEGSample): number => {
        const summarizeValuesBy = this.plot.display.summarize_values_by;

        switch (summarizeValuesBy) {
            case 'mean':
                return d.Average_ZscoreLog2CPM_Control ?? d.Average_Zscore_Control;
            case 'median':
                return d.Median_ZscoreLog2CPM_Control ?? d.Median_Zscore_Control;
            default:
                throw new Error(`Unable to group by "${summarizeValuesBy}"`);
        }
    };

    /** Largest control/experimental value across the plotted rows (0 when there are none). */
    get maxSampleValue() {
        return max(this.sortedData, (d) => Math.max(this.getControlValue(d), this.getExperimentalValue(d))) ?? 0;
    }

    /** Smallest control/experimental value across the plotted rows (0 when there are none). */
    get minSampleValue() {
        return min(this.sortedData, (d) => Math.min(this.getControlValue(d), this.getExperimentalValue(d))) ?? 0;
    }

    /** Color-scale domain bounds, always widened to include 0. */
    get yDomain(): { yMin: number; yMax: number } {
        const yMin = Math.min(this.minSampleValue, 0);
        const yMax = Math.max(this.maxSampleValue, 0);
        return { yMin, yMax };
    }

    /**
     * Five evenly spaced, ascending values across yDomain. They are used as the color-scale domain
     * and as the scale-legend tick values.
     */
    get yBreakpoints() {
        const { yMax, yMin } = this.yDomain;
        const yRange = yMax - yMin;
        const point25 = yMin + yRange / 4;
        const point50 = yMin + yRange / 2;
        const point75 = yMin + (3 * yRange) / 4;

        return { yMin, point25, point50, point75, yMax };
    }

    /**
     * Builds the x (targets), y (groups) and color scales from the current margins and size.
     * The y range is inverted (bottom to top).
     */
    makeScales = () => {
        const margin = this.margin;
        // Reserve room left of the cells for the group color swatches.
        const extraX = this.legendWidth * 2;
        const groupNames = this.getSortedGroups().map((g) => g.display_name);

        const xScaleBand = scaleBand()
            .domain(this.targetNames)
            .range([margin.left + extraX, this.width - margin.right]);

        const yScaleBand = scaleBand()
            .domain(groupNames)
            .range([this.height - margin.bottom, margin.top]);

        const { yMax, yMin, point75, point25, point50 } = this.yBreakpoints;
        const heatmapColors = getComparativeHeatmapColorScale(
            this.plot.display.heatmap_scale_color as ComparativeHeatmapColorScale,
        );
        const colorScale = scaleLinear<string, number>()
            .range(heatmapColors.map((c) => c.color))
            .domain([yMin, point25, point50, point75, yMax]);

        return { x: xScaleBand, y: yScaleBand, color: colorScale };
    };

    /** Initial margins. draw() later replaces `left` and `bottom` with values fitted to the axes. */
    calculateMargins(): PlotMargin {
        return {
            top: 100,
            right: 55,
            bottom: 90,
            left: 110,
        };
    }

    /**
     * Renders the plot in two passes:
     * 1. Draw the axes with the initial margins so their sizes can be measured.
     * 2. Fit the left/bottom margins to those sizes, rebuild the scales and draw everything.
     */
    draw(): void {
        this.appendYAxis();
        this.appendXAxis();
        const yAxisWidth = this.svg.select<SVGGElement>('.y-axis').node()?.getBoundingClientRect().width ?? 0;
        const xAxisHeight = this.svg.select<SVGGElement>('.x-axis').node()?.getBoundingClientRect().height ?? 0;
        this.margin.left = yAxisWidth + 10;
        this.margin.bottom = xAxisHeight + 20;
        this.scales = this.makeScales();
        this.appendYAxis();
        this.appendXAxis();

        this.drawGroups();
        this.drawGroupLegend();
        this.drawScaleLegend();

        // Hide the tooltip on any mouseout event reaching the SVG.
        this.svg.on('mouseout', () => {
            this.tooltip.style('opacity', 0);
        });
    }

    get palette() {
        return getPlotPalette(this.themeColor);
    }

    /** DOM id of the legend's linear gradient, made unique per plot via its uuid. */
    get legendId() {
        return `${this.plot.uuid}-linear-gradient`;
    }

    get plotWidth() {
        return this.width - this.margin.left - this.margin.right;
    }

    /** Caption shown above the scale legend; empty for 'none' or any unrecognized value. */
    get unitsLabel(): string {
        switch (this.plot.display.summarize_values_by) {
            case 'median':
                return `Z-score median expression`;
            case 'mean':
                return 'Z-score average expression';
            case 'none':
            default:
                return '';
        }
    }

    /**
     * Draws the gradient color legend above the heatmap: a caption, a gradient bar, and an axis
     * whose ticks are the yBreakpoints. The bar is half the plot width when the plot is full-width.
     */
    drawScaleLegend = () => {
        const width = this.plotWidth / (this.plot.display?.is_full_width ? 2 : 1);
        const rectHeight = 20;
        const spacing = 40;
        const labelOffset = 6;
        const colors = this.scales.color.range();
        const publicationMode = this.publicationMode;

        this.svg.select(`.scale-legend-gradient`).remove();
        const gradient = this.svg
            .append('defs')
            .attr('class', 'scale-legend-gradient')
            .append('linearGradient')
            .attr('id', this.legendId)
            .attr('x1', '0%')
            .attr('y1', '0%')
            .attr('x2', '100%')
            .attr('y2', '0%');

        // One evenly spaced stop per color (0%, 25%, ... 100% for five colors); clamp the
        // divisor so a single-color range does not divide by zero.
        const lastStopIndex = Math.max(colors.length - 1, 1);
        colors.forEach((hex, i) => {
            gradient
                .append('stop')
                .attr('offset', `${(i / lastStopIndex) * 100}%`)
                .attr('stop-color', hex);
        });

        this.svg.select('.scale-legend-label').remove();
        this.svg
            .append('text')
            .text(this.unitsLabel ?? 'Unknown units')
            .attr('x', this.margin.left + width / 2)
            .attr('y', this.margin.top - rectHeight - spacing)
            .attr('text-anchor', 'middle')
            .attr('fill', publicationMode ? 'black' : 'currentColor')
            .attr('class', cn('scale-legend-label', { 'font-semibold': publicationMode }));

        this.svg.select('.legend-gradient').remove();
        this.svg
            .append('rect')
            .attr('class', 'legend-gradient')
            .attr('x', this.margin.left)
            .attr('y', this.margin.top - rectHeight - spacing + labelOffset)
            .attr('width', width)
            .attr('height', rectHeight)
            .style('fill', `url(#${this.legendId})`);

        const { yMin, yMax } = this.yDomain;
        const legendScale = scaleLinear()
            .domain([yMin, yMax])
            .range([0, width - 1]);

        // Numeric comparator: the default sort() compares numbers as strings (e.g. -1 before -2).
        const ticks = Object.values(this.yBreakpoints).sort((a, b) => a - b);

        this.svg.select('.legend-x-axis').remove();
        const drawLegendScale = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            g.attr('transform', `translate(${this.margin.left},${this.margin.top - spacing + labelOffset})`)
                .attr(
                    'class',
                    cn(publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES, 'legend-x-axis'),
                )
                .call(axisBottom(legendScale).tickSizeOuter(30).tickValues(ticks));

            g.select('.domain').remove();

            return g;
        };

        this.svg.append('g').call(drawLegendScale);
    };

    /**
     * Label for a row. Differential binding prefers `peak_id`; other analysis types prefer
     * `Gene_Symbol`, falling back to `gene_id` and then `peak_id`.
     */
    getDisplayname = (d: DEGSample) => {
        switch (this.plot.analysis?.analysis_type ?? this.plot.analysis_type) {
            case 'differential_binding':
                return d.peak_id ?? d.Gene_Symbol ?? d.gene_id;
            case 'differential_expression':
            default:
                return d.Gene_Symbol ?? d.gene_id ?? d.peak_id;
        }
    };

    /**
     * Name shown for a group: its `custom_legend_json` entry when one is set (truthy), otherwise
     * its own `display_name`. Does not modify the group.
     */
    getGroupDisplayName = (group: AnalysisGroup) => {
        const customLegend = this.plot.display.custom_legend_json;
        const customName = customLegend ? customLegend[group.id] : undefined;
        // `||` rather than `??` so an empty custom name falls back, matching the truthiness check.
        return customName || group.display_name;
    };

    /**
     * Flattens the sorted rows into heatmap cells: one experimental cell then one control cell
     * per row. Group names are resolved the same way as the y-scale domain, so cells line up.
     */
    makeFlatSamples = (): Sample[] => {
        const experimentalName = this.getGroupDisplayName(this.experimentalGroup);
        const controlName = this.getGroupDisplayName(this.controlGroup);
        const samples: Sample[] = [];

        for (const item of this.sortedData) {
            const targetName = this.getDisplayname(item);
            samples.push(
                {
                    group_id: this.experimentalGroup.id,
                    group_name: experimentalName,
                    value: this.getExperimentalValue(item),
                    target_name: targetName,
                    log2_fold_change: item.Log2_Fold_Change,
                },
                {
                    group_id: this.controlGroup.id,
                    group_name: controlName,
                    value: this.getControlValue(item),
                    target_name: targetName,
                    log2_fold_change: item.Log2_Fold_Change,
                },
            );
        }
        return samples;
    };

    /**
     * Draws one colored rect per (target, group) cell and wires up the hover tooltip.
     * Cells whose value is undefined/null are filled transparent.
     */
    drawGroups = () => {
        const scales = this.scales;
        const stats = this.makeFlatSamples();
        const tooltipContainer = this.tooltip;
        const valueName = this.plot.display.summarize_values_by === 'mean' ? 'average' : 'median';

        // Drop cells from a previous draw() so repeated draws do not stack duplicate rects.
        this.svg.selectAll('.heatmap-cell').remove();

        this.svg
            .selectAll()
            .data(stats)
            .enter()
            .append('rect')
            .attr('class', 'heatmap-cell')
            .attr('x', (d) => scales.x(d.target_name) ?? 0)
            .attr('y', (d) => scales.y(d.group_name) ?? 0)
            // +1px overlap hides hairline gaps between adjacent cells.
            .attr('width', scales.x.bandwidth() + 1)
            .attr('height', scales.y.bandwidth() + 1)
            .style('fill', (d) => {
                const value = d.value;
                if (!isDefined(value)) {
                    return 'transparent';
                }
                return scales.color(value);
            })
            .on('mousemove', function (event, d) {
                tooltipContainer.style('opacity', 1);
                // Names are escaped: they come from data files or custom legend settings.
                tooltipContainer
                    .html(
                        `
                                <div class="white-space-normal">
        <span class="block font-semibold text-dark">${ComparativeHeatmapBuilder.escapeHtml(d.target_name)}</span>
        <span class="block text-sm text-gray-600">group: ${ComparativeHeatmapBuilder.escapeHtml(d.group_name)}</span>
        <span class="block text-sm text-gray-600">${valueName}: ${roundToDecimal(d.value)}</span>
        </div>`,
                    )
                    .style('left', `${event.pageX + 15}px`)
                    .style('top', `${event.pageY - 10}px`);
                // Re-append the hovered rect so it is painted above its neighbours.
                this.parentNode?.appendChild(this);
            });
    };

    /**
     * The [control, experimental] groups with `custom_legend_json` display-name overrides applied.
     * Returns shallow copies, so the caller's group objects are never mutated.
     */
    getSortedGroups = (): AnalysisGroup[] => {
        return [this.controlGroup, this.experimentalGroup].map((group) => ({
            ...group,
            display_name: this.getGroupDisplayName(group),
        }));
    };

    /**
     * Draws a color swatch at the left edge of each group row. Colors come from
     * `custom_color_json` (keyed by group id) when set, otherwise from the theme palette by
     * group order.
     */
    drawGroupLegend = () => {
        const scales = this.scales;
        const groups = this.getSortedGroups();
        const customColors = this.plot.display.custom_color_json ?? {};
        const paletteColors = this.palette.colors;
        // Palette colors cycle over all but the last entry. Clamp to 1 so a single-color
        // palette cannot produce `index % 0` (NaN index).
        const cycleLength = Math.max(paletteColors.length - 1, 1);
        const getColor = (group: AnalysisGroup, index: number): string => {
            const customColor = customColors[`${group.id}`];
            if (customColor) {
                return customColor;
            }
            return paletteColors[index % cycleLength].color;
        };

        // Drop swatches from a previous draw() so repeated draws do not stack duplicates.
        this.svg.selectAll('.group-legend-swatch').remove();

        this.svg
            .selectAll()
            .data(groups)
            .enter()
            .append('rect')
            .attr('class', 'group-legend-swatch')
            .attr('x', this.margin.left)
            .attr('y', (d) => scales.y(d.display_name) ?? 0)
            .attr('width', this.legendWidth)
            .attr('height', scales.y.bandwidth() + 1)
            .style('fill', (d, i) => getColor(d, i));
    };

    /** (Re)draws the left axis of group names, wrapping long tick labels at 160px. */
    appendYAxis = () => {
        const scales = this.scales;
        const height = this.height;
        const margin = this.margin;
        const publicationMode = this.publicationMode;
        this.svg.select('.y-axis').remove();

        const drawYAxis = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            g.attr('transform', `translate(${margin.left},0)`)
                .attr(
                    'class',
                    cn(publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES, 'y-axis'),
                )
                .call(axisLeft(scales.y).tickSizeOuter(0))
                // NOTE(review): this axis-title <text> is never given any text content; it is
                // kept because it is part of the group whose width draw() measures.
                .call((axisGroup) =>
                    axisGroup
                        .append('text')
                        .attr('x', -height / 2)
                        .attr('y', -margin.left + 20)
                        .attr('fill', 'currentColor')
                        .attr('text-anchor', 'middle')
                        .attr('transform', 'rotate(-90)')
                        .attr('class', publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES),
                );

            g.selectAll<SVGTextElement, unknown>('.tick text').call(wrapTextNode, 160);
            g.select('.domain').remove();
        };

        this.svg.append('g').call(drawYAxis);
    };

    /**
     * (Re)draws the bottom axis of target names, wrapped at 160px and optionally rotated.
     * The axis is omitted entirely when more than 50 targets are plotted.
     */
    appendXAxis = () => {
        const { x: xScale } = this.scales;
        const height = this.height;
        const margin = this.margin;
        const publicationMode = this.publicationMode;
        const labelRotation = this.xAxisLabelRotation;
        this.svg.selectAll('.x-axis').remove();

        if (this.sortedData.length > 50) {
            return;
        }

        const drawXAxis = (g: Selection<SVGGElement, unknown, BaseType, unknown>) => {
            g.attr('transform', `translate(0,${height - margin.bottom})`)
                .attr(
                    'class',
                    cn(publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES, 'x-axis'),
                )
                .call(
                    axisBottom(xScale)
                        .tickSize(12)
                        .tickSizeOuter(0)
                        .tickFormat((label) => formatTableHeader(label)),
                );

            g.selectAll<SVGTextElement, unknown>('.tick text').call(wrapTextNode, 160);
            if (labelRotation) {
                drawRotatedXAxisLabels(g, labelRotation);
            }
            g.select('.domain').remove();

            return g;
        };

        this.svg.append('g').call(drawXAxis);
    };
}
