import * as d3 from 'd3';
import BasePlotBuilder, {
    PlotMargin,
    ConstructorParams as BaseParams,
} from '@/src/components/plots/builders/BasePlotBuilder';
import { ArrowPlotData, GenericCellData, PlotMapping, RidgePlotItem } from '@models/ExperimentData';
import RidgePlotDisplayOption from '@/src/models/plotDisplayOption/RidgePlotDisplayOption';
import { getPlotPalette } from '@/src/components/ColorPaletteUtil';
import { wrapTextNode } from '@/src/components/plots/PlotUtil';
import { formatStringToNumberWithSeparator, roundToDecimal } from '@/src/util/StringUtil';

/** One kernel-density sample: `[x position, estimated density]`. */
type DensityArray = [number, number];

/** A density curve: samples ordered along the x axis. */
type Density = Array<DensityArray>;

/** Kernel bandwidth handed to the Epanechnikov kernel (same units as the plotted values). */
const bandwidth = 0.45;

/**
 * Multiplier applied to density samples whose x position falls below the compression
 * threshold, flattening the low end of each ridge. 1 disables compression; 0.5 halves
 * those samples, and so on.
 */
const compression_factor = 0.65;

/** Vertical scale for ridges: pixels of rise per unit of density. */
const amplitude = 100;

/** A single plotted value after remapping a raw item through the plot mapping. */
type GroupedItem = {
    /** Identifier of the group the value belongs to. */
    group_id: string;
    /** Display name of the group (custom legend label or mapped group name). */
    group_name: string;
    /** Numeric value used for the density estimate. */
    value: number;
};

/** Constructor arguments accepted by `RidgePlotBuilder`. */
export type ConstructorParams = BaseParams<ArrowPlotData<GenericCellData>>;
export default class RidgePlotBuilder extends BasePlotBuilder<ArrowPlotData<GenericCellData>> {
    constructor(params: ConstructorParams) {
        super(params);
    }

    /**
     * Fixed margins (in px) around the plotting area. `draw` subtracts these from the
     * builder's width/height to get the inner drawing area and translates the plot
     * group by the left/top values.
     */
    calculateMargins(): PlotMargin {
        return {
            top: 135,
            right: 30,
            bottom: 100,
            left: 110,
        };
    }

    /**
     * Builds the HTML shown in the tooltip when hovering a ridge.
     *
     * @param analysisShortname - analysis type; selects the wording of the "% of cells" line.
     * @param groupName - display name of the group (custom legend label or mapped group name).
     * @param groupCellCount - total number of values (cells) in the group.
     * @param percentExpressed - percentage (0-100) of the group's values that `draw` counted as > 0.
     * @returns An HTML fragment. The group name is HTML-escaped because the fragment is
     *   injected with `.html()` and the name may come from user-editable legend JSON.
     */
    getTooltipContent = (
        analysisShortname: string,
        groupName: string,
        groupCellCount: number,
        percentExpressed: number,
    ): string => {
        const getPercentExpressedString = () => {
            switch (analysisShortname) {
                case 'seurat_module_score':
                    return 'with module activity';
                case 'seurat_marker_expression':
                default:
                    return 'expressing target';
            }
        };
        // String() guards against non-string group names coming straight from the data.
        const escapeHtml = (text: string) =>
            String(text)
                .replace(/&/g, '&amp;')
                .replace(/</g, '&lt;')
                .replace(/>/g, '&gt;')
                .replace(/"/g, '&quot;')
                .replace(/'/g, '&#39;');
        return `
<span class="block font-semibold text-dark">Group: ${escapeHtml(groupName)}</span>
<span class="block font-semibold text-dark">Total number of cells: ${formatStringToNumberWithSeparator(groupCellCount)}</span>
<span class="block font-semibold text-dark">% of cells ${getPercentExpressedString()}: ${roundToDecimal(percentExpressed, { decimals: 1 })}%</span>
`;
    };

    /**
     * Renders the ridge plot into `this.svg`.
     *
     * Steps:
     * - Group items by display name (the custom legend label when present).
     * - Order the groups by `group_display_order` and keep only those enabled in the
     *   `groups` visibility map.
     * - For each visible group, draw an Epanechnikov kernel density estimate of its
     *   values as a filled area plus an outline. Each ridge sits in its own band of a
     *   band scale and shows a tooltip on hover.
     * - Add the axis labels, then the axes.
     */
    draw = () => {
        const items = this.data.items as RidgePlotItem[];
        const display = this.plot.display as RidgePlotDisplayOption;
        const shortname = this.plot.analysis_type;
        const dataMap = this.data.plot_mapping as PlotMapping<RidgePlotItem>;
        const customColors = display.custom_color_json ?? {};
        const theme = display.theme_color;
        const customLegend = display.custom_legend_json ?? {};
        const groupDisplayOrder = display.group_display_order ?? [];
        const showCellsGroups = display.groups ?? {};
        const stylingOptions = this.stylingOptions;
        const publicationMode = this.publicationMode;
        const tooltipContainer = this.tooltip;
        // Colour used for axes, and for axis labels without a custom font colour.
        const defaultAxisColor = publicationMode ? 'black' : 'currentColor';

        const svg = this.svg;
        const margin = this.calculateMargins();
        const width = this.width - margin.left - margin.right;
        const height = this.height - margin.top - margin.bottom;

        //// helper functions ////

        // Returns a function that evaluates the kernel density of a sample at each of `samplePoints`.
        const kernelDensityEstimator = (kernel: (offset: number) => number, samplePoints: number[]) => {
            return (sample: number[]): Density =>
                samplePoints.map((point: number) => [point, d3.mean(sample, (v: number) => kernel(point - v)) ?? 0]);
        };

        // Epanechnikov kernel with bandwidth k: 0.75 * (1 - u^2) / k for |u| <= 1 (u = v / k), else 0.
        const kernelEpanechnikov = (k: number) => {
            return (v: number) => {
                const u = v / k;
                return Math.abs(u) <= 1 ? (0.75 * (1 - u * u)) / k : 0;
            };
        };

        // An empty `groups` map means every group is shown; otherwise a group is shown only
        // when its id maps to a truthy value.
        const hasGroupFilter = Object.keys(showCellsGroups).length > 0;
        const filterGroup = ([, groupItems]: [string, GroupedItem[]]) =>
            !hasGroupFilter || Boolean(showCellsGroups[groupItems[0].group_id]);

        // The palette depends only on the theme, so it is looked up once per draw.
        const { colors } = getPlotPalette(theme);
        const getAreaColor = (groupId: string, i: number) => customColors?.[groupId] ?? colors[i % colors.length]?.color;

        // Returns the 10th percentile of the strictly positive values (0, with a warning, when
        // there are none). Density samples whose x lies below it are scaled by `compression_factor`.
        const getCompressionThreshold = (allValues: number[]) => {
            const nonZeroValues = allValues.filter((value) => value > 0);
            if (nonZeroValues.length === 0) {
                console.warn('Ridge plot: No non-zero values available for percentile calculation.');
                return 0;
            }
            // `filter` returned a fresh array, so sorting it in place is safe.
            const sortedValues = nonZeroValues.sort((a, b) => a - b);
            return sortedValues[Math.floor(0.1 * sortedValues.length)];
        };

        const compress = (d: DensityArray, threshold: number) =>
            d[0] < threshold ? d[1] * compression_factor : d[1];

        // end helper functions //

        // Group data by display name (custom legend label when present, else the mapped group name).
        const groupedData = d3.group(
            items.map((item) => ({
                group_id: item[dataMap.group_id],
                group_name: customLegend?.[item[dataMap.group_id]] ?? item[dataMap.group_name],
                value: +item[dataMap.value],
            })),
            (d) => d.group_name,
        );
        // NOTE(review): groups missing from `group_display_order` get index -1 and therefore sort
        // first (bottom of the plot) — confirm this is intended.
        const sortedGroupedData = Array.from(groupedData.entries()).sort(
            (g1, g2) => groupDisplayOrder.indexOf(g1[1][0].group_id) - groupDisplayOrder.indexOf(g2[1][0].group_id),
        );
        const filteredGroupedData = sortedGroupedData.filter(filterGroup);
        const filteredKeysSet = new Set(filteredGroupedData.map(([k]) => k));
        const filteredKeysArray = Array.from(filteredKeysSet);

        // Set the dimensions and margins of the graph
        this._svg?.attr('width', this.width).attr('height', this.height);
        svg.attr('transform', `translate(${margin.left},${margin.top})`);

        // X scale: spans the values of every group (hidden ones included), padded by 1 on each side.
        const allValues = sortedGroupedData.flatMap(([, groupItems]) => groupItems.map((d) => d.value));
        const globalMaxValue = d3.max(allValues) ?? 0;
        const x = d3
            .scaleLinear()
            .domain([-1, globalMaxValue + 1])
            .range([0, width]);
        const compression_threshold = getCompressionThreshold(allValues);

        // Y scale: one band per visible group; the first group occupies the bottom band.
        const y = d3
            .scaleBand()
            .domain(filteredKeysArray)
            .range([height, 0])
            .padding(0.1);

        // X-axis label text depends on the analysis type.
        const getLabelText = () => {
            switch (shortname) {
                case 'seurat_module_score':
                    return 'Module score';
                case 'seurat_marker_expression':
                default:
                    return 'Log-normalized expression';
            }
        };

        // Axis labels
        svg.append('text')
            .attr('class', `axis-label y-axis-label`)
            .attr('x', -30)
            .attr('y', -10)
            .attr('fill', stylingOptions?.yaxis?.fontColor || defaultAxisColor)
            .style('font-size', stylingOptions?.yaxis?.fontSize || 18)
            .style('font-family', stylingOptions?.yaxis?.fontFamily || 'Arial')
            .attr('text-anchor', 'middle')
            .text('Group')
            .call((g) => wrapTextNode(g, (height - margin.bottom) * 0.9));

        // NOTE(review): this query runs before the x-axis label below is appended. On a fresh SVG it
        // finds nothing and the 22px fallback is used. On a re-used SVG it may measure a label left
        // over from an earlier draw — confirm whether the SVG is cleared between draws.
        const labelHeight = svg.select<SVGGElement>('.x-axis-label')?.node()?.getBoundingClientRect().height ?? 22;
        svg.append('text')
            .attr('class', `axis-label x-axis-label`)
            .attr('x', width / 2)
            .attr('y', height + labelHeight + 20)
            .attr('fill', stylingOptions?.xaxis?.fontColor || defaultAxisColor)
            .style('font-size', stylingOptions?.xaxis?.fontSize || 18)
            .style('font-family', stylingOptions?.xaxis?.fontFamily || 'Arial')
            .attr('text-anchor', 'middle')
            .text(getLabelText())
            .call((g) => wrapTextNode(g, width * 0.8));

        // Densities are evaluated at roughly 50 evenly spaced points across the x domain.
        const kde = kernelDensityEstimator(kernelEpanechnikov(bandwidth), x.ticks(50));

        // Ridge geometry is relative to its band: baseline at y.bandwidth(), rising by density * amplitude.
        const ridgeTop = (p: DensityArray) => y.bandwidth() - p[1] * amplitude;
        const ridgeArea = d3
            .area<DensityArray>()
            .curve(d3.curveBasis)
            .x((p) => x(p[0]))
            .y1(ridgeTop)
            .y0(y.bandwidth());
        const ridgeLine = d3
            .line<DensityArray>()
            .curve(d3.curveBasis)
            .x((p) => x(p[0]))
            .y(ridgeTop);

        const hideTooltip = () => {
            tooltipContainer.transition().style('opacity', 0);
        };

        // Draw from the last group to the first. The first group sits in the bottom band, so drawing
        // it last paints lower ridges over the ones above them. The entry index is kept so colours
        // follow the full (unfiltered) display order.
        const drawOrder = Array.from(sortedGroupedData.entries()).reverse();
        for (const [colorIndex, [key, values]] of drawOrder) {
            if (!filteredKeysSet.has(key)) continue;

            // Shift the band up by half its height so the baseline (at y.bandwidth()) lands mid-band.
            const yOffset = (y(key) ?? 0) - y.bandwidth() / 2;
            // Read the remapped id; `dataMap.group_id` names the column in the raw items, not here.
            const groupId = String(values[0].group_id);

            const totalNumberOfValues = values.length;
            const numberOfValuesExpressed = values.filter((d) => d.value > 0).length;
            const percentExpressed = (numberOfValuesExpressed / totalNumberOfValues) * 100;
            const groupName = key;

            // Compressed density samples shared by the filled area and its outline.
            const ridgePoints = kde(values.map((d) => d.value)).map(
                (p): DensityArray => [p[0], compress(p, compression_threshold)],
            );

            const showTooltip = (event: MouseEvent) => {
                tooltipContainer.transition().duration(50).style('opacity', 1);
                tooltipContainer
                    .html(this.getTooltipContent(shortname, groupName, totalNumberOfValues, percentExpressed))
                    .style('left', `${event.pageX + 10}px`)
                    .style('top', `${event.pageY - 10}px`);
            };

            // Filled area
            svg.append('path')
                .attr('class', 'density-area')
                .attr('transform', `translate(0,${yOffset})`)
                .attr('fill', () => getAreaColor(groupId, colorIndex))
                .attr('opacity', 0.8)
                .on('mouseover', showTooltip)
                .on('mouseout', hideTooltip)
                .datum(ridgePoints)
                .attr('d', ridgeArea);

            // Outline drawn after the fill so it stays visible on top of it.
            svg.append('path')
                .attr('class', 'density-line')
                .attr('transform', `translate(0,${yOffset})`)
                .on('mouseover', showTooltip)
                .on('mouseout', hideTooltip)
                .datum(ridgePoints)
                .attr('fill', 'none')
                .attr('stroke', '#000')
                .attr('stroke-width', 1.5)
                .attr('d', ridgeLine);
        }

        // Add the axes last to ensure they are on top
        const xAxis = svg
            .append('g')
            .attr('class', 'x axis')
            .attr('transform', `translate(0,${height})`)
            .call(d3.axisBottom(x).tickSizeOuter(0));
        xAxis.selectAll('.domain').attr('stroke', defaultAxisColor);
        xAxis.selectAll('.tick text').attr('fill', defaultAxisColor);

        const yAxis = svg
            .append('g')
            .attr('class', 'y axis')
            .call(
                d3
                    .axisLeft(y)
                    .tickSizeOuter(0)
                    .tickFormat((value) => {
                        // Long group names are truncated on screen but kept whole when exporting.
                        const maxLength = 20;

                        if (value.length > maxLength && !this.isExportMode) {
                            return `${value.substring(0, maxLength - 3)}...`;
                        }
                        return value;
                    }),
            );
        yAxis.selectAll('.domain').attr('stroke', defaultAxisColor);
        yAxis.selectAll('.tick text').attr('fill', defaultAxisColor);
    };
}
