import * as d3 from 'd3';
import Plot from '@models/Plot';
import Experiment from '@models/Experiment';
import SampleScatterPlotDisplayOption from '@models/plotDisplayOption/SampleScatterPlotDisplayOption';
import { ExperimentAnalysis } from '@models/analysis/ExperimentAnalysis';
import { getPlotPalette } from '@components/ColorPaletteUtil';
import {
    DataPoint,
    PreparedScatterPlotData,
    prepareSampleScatterPlotData,
} from '@components/plots/PrincipalComponentsAnalysisPlotUtil';
import { PCAData, PCASample } from '@models/ExperimentData';
import Logger from '@util/Logger';
import { Selection } from 'd3';
import { AXIS_PADDING_PERCENT, createPlotTooltip, wrapTextNode } from '@components/plots/PlotUtil';
import { isDefined } from '@util/TypeGuards';
import {
    AXIS_LABEL_CLASSNAMES,
    AXIS_LABEL_PUBLICATION_CLASSNAMES,
    AXIS_TITLE_CLASSNAMES,
    AXIS_TITLE_PUBLICATION_CLASSNAMES,
} from '@models/PlotConfigs';
import cn from 'classnames';
import { PaletteColor } from '@components/PaletteColors';
import { CustomPlotStylingOptions } from '@components/analysisCategories/comparative/plots/PlotlyVolcanoPlotUtil';

const logger = Logger.make('SampleScatterPlotBuilder');

/** d3 selection wrapping the root `<svg>` element the plot is rendered into. */
type SVGSelection = d3.Selection<SVGSVGElement, unknown, d3.BaseType, unknown>;

/** Inputs required to build one sample scatter plot. */
interface ConstructorParams {
    /** Target svg element that the plot's axes, labels, lines and dots are appended to. */
    svg: SVGSelection;
    /** Plot whose `display` options (axis columns, colors, axis bounds, ...) drive rendering. */
    plot: Plot<ExperimentAnalysis, SampleScatterPlotDisplayOption>;
    experiment: Experiment;
    /** Raw PCA data handed to `prepareSampleScatterPlotData`. */
    data: PCAData;
    /** Overall drawing size used to lay out the plot. */
    size: { height: number; width: number };
    /** Looks up the proportion of variance for a column (rendered ×100 as a percentage), or null when unknown. */
    getProportionOfVariance: (column: string) => number | null;
    /** Switches to publication classnames and black text; treated as false when omitted. */
    publicationMode?: boolean;
    /** DOM id passed to `createPlotTooltip`. */
    tooltipId: string;
    /** Optional font color/size/family overrides for the axis titles. */
    stylingOptions: CustomPlotStylingOptions | null;
}

/** Parameters accepted by `SampleScatterPlotBuilder.make` (same shape as the constructor's). */
type BuilderParams = ConstructorParams;
/**
 * Renders a sample scatter plot (one dot per PCA sample) into an existing SVG element with d3.
 *
 * Create instances through {@link SampleScatterPlotBuilder.make}, then call {@link draw}
 * to (re)render the plot from `data` and the plot's display options.
 */
export default class SampleScatterPlotBuilder {
    svg: SVGSelection;
    plot: Plot<ExperimentAnalysis, SampleScatterPlotDisplayOption>;
    experiment: Experiment;
    data: PCAData;
    /** Result of the most recent `prepareSampleScatterPlotData` call made by `draw()`; null until then. */
    preparedData: PreparedScatterPlotData | null = null;
    tooltipId: string;
    size: { height: number; width: number };
    /** Tooltip element created from `tooltipId`; shown on dot hover. */
    tooltipContainer: Selection<HTMLDivElement, unknown, HTMLElement, unknown>;
    getProportionOfVariance: (column: string) => number | null = () => null;
    publicationMode = false;
    stylingOptions: CustomPlotStylingOptions | null;

    private constructor(options: ConstructorParams) {
        this.size = options.size;
        this.svg = options.svg;
        this.plot = options.plot;
        this.experiment = options.experiment;
        this.data = options.data;
        this.tooltipId = options.tooltipId;
        this.tooltipContainer = createPlotTooltip(this.tooltipId);
        this.getProportionOfVariance = options.getProportionOfVariance;
        this.publicationMode = options.publicationMode ?? false;
        this.stylingOptions = options.stylingOptions;
    }

    /** Factory for a builder; the constructor is private. */
    static make(params: BuilderParams): SampleScatterPlotBuilder {
        return new SampleScatterPlotBuilder(params);
    }

    /** Display options of the underlying plot. */
    get options(): SampleScatterPlotDisplayOption {
        return this.plot.display;
    }

    /** Palette selected by the plot's `theme_color` option. */
    get palette() {
        return getPlotPalette(this.options.theme_color);
    }

    /**
     * Escapes `&`, `<`, `>`, `"` and `'` so arbitrary values (e.g. user-supplied sample ids or
     * group names) can be interpolated into tooltip HTML without being interpreted as markup.
     */
    private static escapeHtml(value: unknown): string {
        return String(value)
            .replace(/&/g, '&amp;')
            .replace(/</g, '&lt;')
            .replace(/>/g, '&gt;')
            .replace(/"/g, '&quot;')
            .replace(/'/g, '&#39;');
    }

    /**
     * Draws axes, axis titles, zero reference lines and one dot per sample.
     *
     * Layout uses two passes. The axes are drawn once so that their rendered sizes can be
     * measured. `margin.left`/`margin.bottom` are then updated from those measurements, the
     * scales are rebuilt, and the axes are drawn again.
     *
     * @param preparedData output of `prepareSampleScatterPlotData` (items plus x/y min/max stats)
     */
    legacyDraw(preparedData: PreparedScatterPlotData) {
        const svg = this.svg;
        const tooltipContainer = this.tooltipContainer;
        const allSamples = preparedData.items;
        const publicationMode = this.publicationMode;
        const options = this.options;
        const size = this.size;
        const { x_axis_column, y_axis_column, custom_color_json } = this.options;
        const customColors = custom_color_json ?? {};
        const { x: xStats, y: yStats } = preparedData;

        svg.selectAll('g').remove();
        const height = size?.height ?? 500;
        const width = size?.width ?? 500;
        // Mutated by drawYAxis/drawXAxis after measuring the rendered axes and titles.
        const margin = { top: 20, right: 20, bottom: 46, left: 90 };
        const yPadding = Math.abs(yStats.max - yStats.min) * AXIS_PADDING_PERCENT;
        const xPadding = Math.abs(xStats.max - xStats.min) * AXIS_PADDING_PERCENT;

        // Explicit axis bounds from the display options win over the padded data extent.
        const yMin = options.y_axis_start ?? yStats.min - yPadding;
        const yMax = options.y_axis_end ?? yStats.max + yPadding;

        const xMin = options.x_axis_start ?? xStats.min - xPadding;
        const xMax = options.x_axis_end ?? xStats.max + xPadding;

        // Scales read `margin` at call time, so they are rebuilt after the margins are measured.
        const makeYScale = () => {
            return d3
                .scaleLinear()
                .domain([yMin, yMax])
                .rangeRound([height - margin.bottom, margin.top]);
        };

        const makeXScale = () =>
            d3
                .scaleLinear()
                .domain([xMin, xMax])
                .rangeRound([margin.left, width - margin.right]);

        let yScale = makeYScale();
        let xScale = makeXScale();

        svg.selectAll('.axis-label').remove();
        svg.selectAll('.y-axis').remove();

        let yAxisLabel = y_axis_column;
        let xAxisLabel = x_axis_column;
        if (options.show_proportion_of_variance) {
            const yVariance = this.getProportionOfVariance(y_axis_column);
            const xVariance = this.getProportionOfVariance(x_axis_column);
            if (isDefined(yVariance)) {
                yAxisLabel = `${y_axis_column} (${(yVariance * 100).toFixed(1)}% variance explained)`;
            }
            if (isDefined(xVariance)) {
                xAxisLabel = `${x_axis_column} (${(xVariance * 100).toFixed(1)}% variance explained)`;
            }
        }

        /* Y-Axis */
        const drawYAxis = () => {
            const styles = this.stylingOptions?.yaxis;
            // Drop any previous y-axis and title before re-drawing.
            svg.select('.y-axis-label').remove();
            svg.select('.y-axis').remove();

            // Add y-axis title (rotated, centered along the plot height)
            const label = svg
                .append('text')
                .attr(
                    'class',
                    `axis-label y-axis-label ${
                        publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES
                    }`,
                )
                .attr('x', -(height - margin.bottom) / 2)
                .attr('y', 18)
                .attr('fill', styles ? styles.fontColor : publicationMode ? 'black' : 'currentColor')
                .style('font-size', styles ? styles.fontSize : '18')
                .style('font-family', styles ? styles.fontFamily : 'Arial')
                .attr('text-anchor', 'middle')
                .attr('transform', 'rotate(-90)')
                .text(yAxisLabel)
                .call((g) => wrapTextNode(g, (height - margin.bottom) * 0.9));

            const labelWidth = label.node()?.getBoundingClientRect()?.width ?? 18;
            // Scientific notation when |yMax| > 10,000 or |yMin| < 0.0001; otherwise grouped fixed-point.
            const yAxisFormat = Math.abs(yMax) > 10_000 || Math.abs(yMin) < 0.0001 ? '.1e' : ',f';
            const yAxis = (g: d3.Selection<SVGGElement, unknown, d3.BaseType, any>) =>
                g
                    .call((g) => g.select('.domain').remove())
                    .attr('transform', `translate(${margin.left},0)`)
                    .attr(
                        'class',
                        `y-axis ${publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES}`,
                    )
                    .call(d3.axisLeft(yScale).ticks(5, yAxisFormat).tickSizeOuter(0))
                    // Drop the first tick.
                    .call((g) => g.select('.y-axis > .tick:first-of-type').remove());

            // Append the y-axis, then size the left margin from what was actually rendered.
            svg.append('g').call(yAxis);
            const yAxisWidth = svg.select<SVGGElement>('.y-axis')?.node()?.getBoundingClientRect().width ?? 70;
            margin.left = yAxisWidth + labelWidth + 12;
        };

        /* X-Axis */
        const drawXAxis = () => {
            const styles = this.stylingOptions?.xaxis;
            // Height of the title from the previous pass (if any) positions the new title.
            let labelHeight = svg.select<SVGGElement>('.x-axis-label')?.node()?.getBoundingClientRect().height ?? 24;
            svg.select('.x-axis-label').remove();
            svg.select('.x-axis').remove();
            // Add x-axis title
            svg.append('text')
                .attr(
                    'class',
                    `axis-label x-axis-label ${
                        publicationMode ? AXIS_TITLE_PUBLICATION_CLASSNAMES : AXIS_TITLE_CLASSNAMES
                    }`,
                )
                .attr('x', (width + margin.left) / 2)
                .attr('y', height - labelHeight + 12)
                .attr('fill', styles ? styles.fontColor : publicationMode ? 'black' : 'currentColor')
                .style('font-size', styles ? styles.fontSize : '18')
                .style('font-family', styles ? styles.fontFamily : 'Arial')
                .attr('text-anchor', 'middle')
                .text(xAxisLabel)
                .call((g) => wrapTextNode(g, width * 0.8));

            const xAxis = (g: d3.Selection<SVGGElement, unknown, d3.BaseType, any>) => {
                return g
                    .attr('transform', `translate(0,${height - margin.bottom})`)
                    .attr(
                        'class',
                        cn('x-axis', publicationMode ? AXIS_LABEL_PUBLICATION_CLASSNAMES : AXIS_LABEL_CLASSNAMES),
                    )
                    .call(d3.axisBottom(xScale).ticks(5).tickSizeOuter(0));
            };

            svg.append('g').call(xAxis);

            // Size the bottom margin from the rendered axis and title.
            labelHeight = svg.select<SVGGElement>('.x-axis-label')?.node()?.getBoundingClientRect().height ?? 22;
            const xAxisHeight = svg.select<SVGGElement>('.x-axis')?.node()?.getBoundingClientRect().height ?? 22;
            margin.bottom = xAxisHeight + labelHeight + 12;
        };

        // Pass 1: draw to measure; pass 2: redraw with scales fitted to the measured margins.
        drawYAxis();
        drawXAxis();

        yScale = makeYScale();
        xScale = makeXScale();

        drawYAxis();
        drawXAxis();

        svg.selectAll('.coordinate-line').remove();
        /* Reference lines through the origin, only when 0 is inside the visible domain. */
        // A linear scale extrapolates outside its domain, so an unguarded line would land on the axes/margins.
        if (Math.min(xMin, xMax) <= 0 && 0 <= Math.max(xMin, xMax)) {
            // Vertical line (x=0)
            svg.append('line')
                .attr('class', `coordinate-line text-gray-400`)
                .attr('stroke', 'currentColor')
                .style('stroke-width', 0.25)
                .attr('x1', xScale(0))
                .attr('x2', xScale(0))
                .attr('y1', yScale(yMin))
                .attr('y2', yScale(yMax));
        }

        if (Math.min(yMin, yMax) <= 0 && 0 <= Math.max(yMin, yMax)) {
            // Horizontal line (y=0)
            svg.append('line')
                .attr('class', `coordinate-line text-gray-400`)
                .attr('stroke', 'currentColor')
                .style('stroke-width', 0.25)
                .attr('x1', xScale(xMin))
                .attr('x2', xScale(xMax))
                .attr('y1', yScale(0))
                .attr('y2', yScale(0));
        }

        // Start of the dots
        // Palette color by the group's position in `group_display_order`; unknown groups use the first color.
        const getPointPaletteColor = (d: PCASample): PaletteColor => {
            let colorIndex = (options.group_display_order ?? []).indexOf(Number(d.group_id));
            if (colorIndex < 0) {
                colorIndex = 0;
            }
            const paletteOption = this.palette.colors;
            const paletteLength = paletteOption.length;
            return paletteOption[colorIndex % paletteLength];
        };

        // A per-group custom color (keyed by group id) overrides the palette color.
        const getCircleFillColor = (d: DataPoint) => customColors[`${d.group_id}`] ?? getPointPaletteColor(d).color;

        const escape = SampleScatterPlotBuilder.escapeHtml;
        // Sample id and group name come from experiment data, so they are escaped before `.html()`.
        const getTooltipContent = (d: DataPoint) => {
            return `
<span class="block font-semibold text-dark">${escape(d.Sample_ID)}</span>
<span class="block text-sm text-gray-600">x: ${d.x.toFixed(4)}</span>
<span class="block text-sm text-gray-600">y: ${d.y.toFixed(4)}</span>
<span class="block text-sm text-gray-600">group: ${escape(d.group_name)}</span>
`;
        };

        const DOT_RADIUS = 4.5;
        // Draw Scatter Plot
        // Add dots
        svg.append('g')
            .selectAll('dot')
            .data(allSamples)
            .enter()
            .append('circle')
            .attr('cx', function (d) {
                return xScale(d.x);
            })
            .attr('cy', function (d) {
                return yScale(d.y);
            })
            .attr('r', DOT_RADIUS)
            .attr('stroke', (d) => getCircleFillColor(d))
            .attr('stroke-width', 1)
            .style('fill', getCircleFillColor)
            .style('fill-opacity', '0.75')
            .on('mouseover', function (event, d) {
                const circle = d3.select(event.target);
                circle
                    .style('cursor', 'crosshair')
                    .style('fill-opacity', '1')
                    .style('stroke', 'red')
                    .style('fill', 'red');
                tooltipContainer
                    .style('opacity', '1')
                    .html(getTooltipContent(d))
                    .style('left', `${event.pageX + 10}px`)
                    .style('top', `${event.pageY - 10}px`);
                // Re-append as the parent's last child so the highlighted dot renders above its neighbours.
                circle.raise();
            })
            .on('mouseout', function (event, d) {
                const circle = d3.select(event.target);
                circle
                    .style('fill-opacity', '.75')
                    .style('fill', getCircleFillColor(d))
                    .style('stroke', getCircleFillColor(d));
                tooltipContainer.style('opacity', 0);
            });
    }

    /**
     * Prepares `data` for the current display options and renders the plot.
     *
     * When no prepared data is produced, a warning is logged and the previous rendering is cleared.
     */
    draw() {
        const svg = this.svg;
        svg.selectAll('g').remove();
        const preparedData = prepareSampleScatterPlotData(this.data, this.options);
        this.preparedData = preparedData;
        if (!preparedData) {
            logger.warn('no prepared data found, removing chart.');
            // Axis titles and zero lines are appended directly to the svg (not inside a <g>),
            // so clear them explicitly or stale ones from a previous render would remain.
            svg.selectAll('.axis-label').remove();
            svg.selectAll('.coordinate-line').remove();
            return;
        }

        this.legacyDraw(preparedData);
    }
}
