import * as d3 from 'd3';
import BasePlotBuilder, {
    PlotMargin,
    ConstructorParams as BaseParams,
} from '@/src/components/plots/builders/BasePlotBuilder';
import {
    ArrowPlotData,
    GenericCellData,
    NetworkGraphLink,
    NetworkGraphNode,
    NetworkGraphItem,
} from '@models/ExperimentData';
import { getPlotPalette } from '@/src/components/ColorPaletteUtil';
import NetworkGraphDisplayOption from '@/src/models/plotDisplayOption/NetworkGraphDisplayOption';

/**
 * A d3 zoom transform annotated with the plot dimensions that were in effect
 * when the transform was captured.
 *
 * In this file the two extra fields are filled from the builder's current
 * `width`/`height` inside the zoom handler, and read back when restoring a
 * saved transform so the stored x/y offsets can be rescaled to the current
 * plot size.
 */
interface ZoomTransformWithDimensions extends d3.ZoomTransform {
    /** Plot width at the time the transform was recorded. */
    originalWidth: number;
    /** Plot height at the time the transform was recorded. */
    originalHeight: number;
}

/** Constructor parameters for `ForceDirectedGraph`: the base builder params specialised to arrow-plot data. */
export type ConstructorParams = BaseParams<ArrowPlotData<GenericCellData>>;
/**
 * Plot builder that renders `NetworkGraphItem` rows as a force-directed
 * network graph: one node per unique protein name and one line per item.
 *
 * Supports parallel edges between the same pair of nodes (drawn side by side),
 * node dragging, hover highlighting of direct neighbours, and zoom/pan with a
 * restorable zoom transform.
 */
export default class ForceDirectedGraph extends BasePlotBuilder<ArrowPlotData<GenericCellData>> {
    /**
     * The simulation started by the most recent `draw()` call, kept so it can
     * be stopped when the plot is redrawn or unmounted.
     */
    invalidation: d3.Simulation<NetworkGraphNode, any> | null = null;

    constructor(params: ConstructorParams) {
        super(params);
    }

    /** Stops the simulation started by the last `draw()`, if there is one. */
    componentWillUnmount() {
        if (this.invalidation) {
            this.invalidation.stop();
            this.invalidation = null;
        }
    }

    /** The graph uses the full drawing area, so all margins are zero. */
    calculateMargins(): PlotMargin {
        return { top: 0, left: 0, right: 0, bottom: 0 };
    }

    /**
     * Builds the node/link data from `this.data.items`, renders it into
     * `this.svg`, and starts the force simulation that positions the nodes.
     *
     * Any simulation left over from a previous `draw()` is stopped first so
     * that re-drawing does not leave an orphaned simulation ticking.
     */
    draw = () => {
        // Stop the simulation from a previous draw before starting a new one.
        if (this.invalidation) {
            this.invalidation.stop();
            this.invalidation = null;
        }

        const items = this.data.items as NetworkGraphItem[];
        const display = this.plot.display as NetworkGraphDisplayOption;

        // Palette and custom colors do not change during a draw, so look them up once
        // instead of once per link/node.
        const customColors = display.custom_color_json ?? {};
        const { colors } = getPlotPalette(display.theme_color);

        /**
         * Returns the custom color registered under `edgeType`, falling back to
         * the theme palette entry at position `i` (wrapping around).
         */
        const getLinkColor = (edgeType: string | number, i: number) => {
            return customColors?.[edgeType] ?? colors[i % colors.length]?.color;
        };

        /**
         * Stroke width for a link: 0 when its edge type is not among the
         * displayed types, the raw value when weighted edges are enabled,
         * otherwise 1 for any truthy value.
         */
        const getLinkWidth = (d: NetworkGraphLink) => {
            if (!display?.edge_types?.includes(d.edge_type)) {
                return 0;
            }

            if (display?.use_weighted_edges) {
                return d.value;
            }

            return d.value ? 1 : 0;
        };

        /** Collects every distinct name appearing as either endpoint of an item. */
        const getUniqueNodeNames = () => {
            const uniqueNames = new Set();

            items.forEach((item) => {
                uniqueNames.add(item.protein_a);
                uniqueNames.add(item.protein_b);
            });

            return Array.from(uniqueNames);
        };

        // Data formatting
        const nodeNames = getUniqueNodeNames();
        const nodes = nodeNames.map((name) => ({ id: name, group: 1 }));
        const links = items.map((i, key) => {
            return {
                source: i.protein_a,
                target: i.protein_b,
                value: i.edge_value,
                i: key,
                // Spaces in the edge type are replaced with underscores.
                edge_type: i.edge_type?.split(' ').join('_'),
            };
        }) as NetworkGraphLink[];

        /**
         * Annotates each link (in place) with its 1-based position among the
         * links joining the same unordered node pair (`multiIdx`) and the
         * sideways offset (`targetDistance`) at which it should be drawn so
         * that parallel links do not overlap.
         */
        const prepareLinks = () => {
            const linkWidth = 0.3;
            const linksFromNodes: { [key: string]: number[] } = {};
            return links.map((val: NetworkGraphLink, idx) => {
                const sid = val.source;
                const tid = val.target;
                // Order-independent key so a->b and b->a count as the same pair.
                const key = sid < tid ? sid + ',' + tid : tid + ',' + sid;
                if (linksFromNodes[key] === undefined) {
                    linksFromNodes[key] = [idx];
                    val.multiIdx = 1;
                } else {
                    // Array#push returns the new length, i.e. this link's 1-based position.
                    val.multiIdx = linksFromNodes[key].push(idx);
                }
                // Calculate target link distance, from the index in the multiple-links array:
                // 1 -> 0, 2 -> 2, 3-> -2, 4 -> 4, 5 -> -4, ... (each multiplied by linkWidth)
                val.targetDistance =
                    val.multiIdx % 2 === 0 ? val.multiIdx * linkWidth : (-val.multiIdx + 1) * linkWidth;
                return val;
            });
        };
        const preparedLinks = prepareLinks();

        // Create link lines
        const link = this.svg
            .append('g')
            .selectAll()
            .data(preparedLinks)
            .join('line')
            .attr('stroke', (d: NetworkGraphLink) => getLinkColor(d.edge_type, d.i))
            .attr('stroke-opacity', 1)
            .attr('stroke-width', (d) => getLinkWidth(d));

        /**
         * Computes a translation perpendicular to the segment point0 -> point1.
         *
         * The offset along the segment's minor axis is exactly `targetDistance`
         * in magnitude; the other component is chosen so the translation is
         * perpendicular to the segment. The side chosen does not depend on the
         * direction of the segment, so a->b and b->a links are offset consistently.
         *
         * @param targetDistance signed sideways offset for the link
         * @param point0 first end point of the line segment
         * @param point1 second end point of the line segment
         * @returns a translation `{dx, dy}`; `{0, 0}` when `targetDistance` is 0
         *          or the segment has zero length (no direction to be perpendicular to)
         */
        const calcTranslationExact = (
            targetDistance: number,
            point0: { x: number; y: number },
            point1: { x: number; y: number },
        ) => {
            const x1_x0 = point1.x - point0.x;
            const y1_y0 = point1.y - point0.y;
            let x2_x0: number, y2_y0: number;

            if (targetDistance === 0 || (x1_x0 === 0 && y1_y0 === 0)) {
                // Zero-length segments would otherwise divide 0 by 0 and emit NaN.
                x2_x0 = y2_y0 = 0;
            } else if (y1_y0 === 0 || Math.abs(x1_x0 / y1_y0) > 1) {
                // Mostly horizontal segment: shift vertically by the target distance.
                y2_y0 = -targetDistance;
                x2_x0 = (targetDistance * y1_y0) / x1_x0;
            } else {
                // Mostly vertical segment: shift horizontally by the target distance.
                x2_x0 = targetDistance;
                y2_y0 = (targetDistance * -x1_x0) / y1_y0;
            }

            return {
                dx: x2_x0,
                dy: y2_y0,
            };
        };

        // Tick action handler
        const ticked = () => {
            // Advance the simulation a few extra steps per rendered frame
            // (manual tick() calls do not dispatch further tick events).
            for (let i = 0; i < 5; i++) {
                simulation.tick();
            }

            link.attr('x1', (d: any) => d.source.x)
                .attr('y1', (d: any) => d.source.y)
                .attr('x2', (d: any) => d.target.x)
                .attr('y2', (d: any) => d.target.y)
                .attr('transform', function (d: NetworkGraphLink) {
                    const translation = calcTranslationExact(
                        d.targetDistance as number,
                        d.source as { x: number; y: number },
                        d.target as { x: number; y: number },
                    );
                    return `translate (${translation.dx}, ${translation.dy})`;
                });

            node.attr('transform', (d) => `translate(${(d as NetworkGraphNode).x}, ${(d as NetworkGraphNode).y})`);
        };

        const width = this.width;

        // Initiate simulation
        const simulation = d3
            .forceSimulation(nodes as NetworkGraphNode[])
            .force(
                'link',
                d3
                    .forceLink(links)
                    .id((d) => (d as NetworkGraphNode).id)
                    .distance(() => width / 4),
            )
            .force('charge', d3.forceManyBody())
            .force('center', d3.forceCenter(this.width / 2, this.height / 2))
            .force('collide', d3.forceCollide().radius(10))
            .on('tick', ticked);

        // Remember the simulation so a later draw() or componentWillUnmount() can stop it.
        this.invalidation = simulation;

        // Pre-compute the layout synchronously: this is the number of ticks needed
        // for alpha to decay from its initial value down to alphaMin.
        for (
            let i = 0, n = Math.ceil(Math.log(simulation.alphaMin()) / Math.log(1 - simulation.alphaDecay()));
            i < n;
            ++i
        ) {
            simulation.tick();
        }

        // Drag handles
        const dragstarted = (event: any) => {
            // Reheat the simulation when the first drag gesture starts.
            if (!event.active) simulation.alphaTarget(0.1).restart();
            // Pin the dragged node at its current position.
            (event.subject as any).fx = (event.subject as any).x;
            (event.subject as any).fy = (event.subject as any).y;
        };
        const dragged = (event: any) => {
            (event.subject as any).fx = event.x;
            (event.subject as any).fy = event.y;
        };
        const dragended = (event: any) => {
            // Let the simulation cool down again and release the node.
            if (!event.active) simulation.alphaTarget(0);
            (event.subject as any).fx = null;
            (event.subject as any).fy = null;
        };

        // Create node group with labels and drag capabilities
        const node = this.svg
            .append('g')
            .selectAll<SVGGElement, NetworkGraphNode>('.node')
            .data(nodes as NetworkGraphNode[])
            .join('g')
            .attr('class', 'node')
            .call(
                d3
                    .drag<SVGGElement, NetworkGraphNode>()
                    .on('start', dragstarted)
                    .on('drag', dragged)
                    .on('end', dragended),
            );

        node.append('circle')
            .attr('r', 8)
            .attr('fill', (d) => getLinkColor(d.group || 0, 0));

        // Add hover tooltips if labels are hidden
        if (display?.hide_labels) {
            node.append('title').text((d) => (d as NetworkGraphNode).id);
        } else {
            // Duplicate text with white outline for better readability of labels
            node.append('text')
                .attr('x', 8)
                .attr('y', '.31em')
                .style('font-size', '10px')
                .attr('stroke', '#fff')
                .attr('stroke-width', 3)
                .style('opacity', 0.8)
                .text((d) => (d as NetworkGraphNode).id);
            node.append('text')
                .text((d) => (d as NetworkGraphNode).id)
                .style('fill', '#000')
                .style('font-size', '10px')
                .attr('x', 8)
                .attr('y', '.31em');
        }

        // On hover, highlight node and direct connections
        const linkedByIndex: Record<string, number> = {};
        links.forEach(function (d: any) {
            // By now forceLink has replaced source/target with the node objects.
            linkedByIndex[d.source.index + ',' + d.target.index] = 1;
        });
        /** True when `a` and `b` are the same node or share a link in either direction. */
        function isConnected(a: NetworkGraphNode, b: NetworkGraphNode) {
            return (
                linkedByIndex[b.index + ',' + a.index] || linkedByIndex[a.index + ',' + b.index] || a.index === b.index
            );
        }
        /**
         * Builds a mouse listener that sets every node and link that is NOT
         * directly connected to the hovered node to `opacity`, and everything
         * that is connected (including the hovered node itself) to 1.
         */
        function fade(opacity: number) {
            // D3 (v6+) passes the bound datum of the listening element as the second argument.
            return function (_event: MouseEvent, hovered: NetworkGraphNode) {
                const connectedIndices = new Set<NetworkGraphNode['index']>();
                node.each((o) => {
                    const other = o as NetworkGraphNode;
                    if (isConnected(hovered, other)) {
                        connectedIndices.add(other.index);
                    }
                });

                node.transition().style('fill-opacity', (o) =>
                    connectedIndices.has((o as NetworkGraphNode).index) ? 1 : opacity,
                );

                link.transition().style('stroke-opacity', (o: any) =>
                    hovered.id === o.source.id || hovered.id === o.target.id ? 1 : opacity,
                );
            };
        }
        node.on('mouseover', fade(0.3)).on('mouseout', fade(1));

        const zoomed = (e: d3.D3ZoomEvent<SVGSVGElement, unknown>) => {
            this.svg.attr('transform', e.transform.toString());
            if (this.onZoomTransform) {
                // Report the transform together with the plot size it applies to,
                // so it can be rescaled if restored at a different size.
                const customTransform = {
                    ...e.transform,
                    originalWidth: this.width,
                    originalHeight: this.height,
                };
                this.onZoomTransform(customTransform as ZoomTransformWithDimensions);
            }
        };
        const zoom = d3.zoom<SVGSVGElement, unknown>().on('zoom', zoomed);

        // Enable zoom and pan capabilities
        if (this.zoomEnabled) {
            this._svg?.call(zoom);
        }

        /**
         * Rescales a saved offset from the plot extent it was recorded at to the
         * current extent. Falls back to the raw offset when the recorded extent is
         * missing or zero, and to 0 when the offset itself is not a finite number,
         * so a NaN/Infinity translation is never applied.
         */
        const rescaleOffset = (offset: number, recordedExtent: number, currentExtent: number) => {
            if (!Number.isFinite(offset)) {
                return 0;
            }
            if (!Number.isFinite(recordedExtent) || recordedExtent === 0) {
                return offset;
            }
            return (offset / recordedExtent) * currentExtent;
        };

        // Set initial zoom after attaching zoom behavior
        const customOptionsJSON = (this.plot.display as NetworkGraphDisplayOption).custom_options_json;
        let initialTransform = d3.zoomIdentity;
        if (customOptionsJSON && !isNaN(customOptionsJSON?.zoomTransform?.k)) {
            const { x, y, k, originalWidth, originalHeight } = customOptionsJSON.zoomTransform;
            const scaledX = rescaleOffset(x, originalWidth, this.width);
            const scaledY = rescaleOffset(y, originalHeight, this.height);
            initialTransform = initialTransform.translate(scaledX, scaledY).scale(k);
        }
        this._svg?.call(zoom.transform, initialTransform);

        // Explicitly stop the simulation once it reports that it has ended.
        // (Stopping a simulation left over from an earlier draw is handled at the top of draw().)
        simulation.on('end', () => simulation.stop());
    };
}
