import {
    forceSimulation,
    forceLink,
    forceManyBody,
    forceX,
    forceY,
    forceCollide
} from 'd3-force';
import React, { useCallback } from 'react';
import {
    useReactFlow,
} from '@xyflow/react';
import { quadtree } from 'd3-quadtree';

/**
 * Creates a d3-force compatible force that pushes apart nodes whose
 * axis-aligned rectangles overlap.
 *
 * Every node is treated as a rectangle centred on (`x`, `y`) with the node's
 * `width` and `height`. For each overlapping pair both nodes are moved in
 * opposite directions along the line joining their centres, scaled by the
 * remaining overlap on each axis, the configured strength and the current
 * simulation alpha. The force edits `x`/`y` directly.
 *
 * Two nodes whose centres coincide exactly are left in place, because the
 * offset between them (and therefore the computed displacement) is zero.
 *
 * @returns {Function} `force(alpha)` with the extra methods
 *   `initialize(nodes)`, `strength([value])` and `iterations([value])`;
 *   the two accessors return the current value when called without arguments
 *   and the force itself (for chaining) when called with one.
 */
const D3ForceCollide = () => {
    let nodes;
    let strength = 1;
    let iterations = 1;

    // Moves `node` and `other` apart if their rectangles overlap on both axes.
    const resolveOverlap = (node, other, alpha) => {
        let x = node.x - other.x;
        let y = node.y - other.y;

        const xSpacing = (node.width + other.width) / 2;
        const ySpacing = (node.height + other.height) / 2;
        const absX = Math.abs(x);
        const absY = Math.abs(y);

        if (!(absX < xSpacing && absY < ySpacing)) {
            return;
        }

        // `|| 1` avoids dividing by zero when the centres coincide.
        const l = Math.sqrt(x * x + y * y) || 1;
        x *= ((absX - xSpacing) / l) * strength * alpha;
        y *= ((absY - ySpacing) / l) * strength * alpha;

        node.x -= x;
        node.y -= y;
        other.x += x;
        other.y += y;
    };

    function force(alpha) {
        if (!nodes || nodes.length === 0) {
            return;
        }

        // The quadtree indexes centre points only, so the search box around a
        // node has to be widened by the largest half-extent any other node
        // can have; otherwise a large neighbour whose centre lies outside the
        // node's own rectangle would be pruned although the two overlap.
        let maxHalfWidth = 0;
        let maxHalfHeight = 0;
        for (const { width, height } of nodes) {
            if (width / 2 > maxHalfWidth) maxHalfWidth = width / 2;
            if (height / 2 > maxHalfHeight) maxHalfHeight = height / 2;
        }

        for (let k = 0; k < iterations; ++k) {
            // Rebuilt on every pass so pruning uses the positions produced by
            // the previous pass instead of a stale index.
            const tree = quadtree(
                nodes,
                d => d.x,
                d => d.y
            );

            for (const node of nodes) {
                const reachX = node.width / 2 + maxHalfWidth;
                const reachY = node.height / 2 + maxHalfHeight;
                const nx1 = node.x - reachX;
                const ny1 = node.y - reachY;
                const nx2 = node.x + reachX;
                const ny2 = node.y + reachY;

                tree.visit((quad, x1, y1, x2, y2) => {
                    // Leaf quads have no `length`; coincident points are
                    // chained through `next`.
                    if (!quad.length) {
                        for (let leaf = quad; leaf; leaf = leaf.next) {
                            if (leaf.data !== node) {
                                resolveOverlap(node, leaf.data, alpha);
                            }
                        }
                    }
                    // Returning true skips the children of a quad that lies
                    // completely outside the search box.
                    return x1 > nx2 || x2 < nx1 || y1 > ny2 || y2 < ny1;
                });
            }
        }
    }

    // Called with the node array the force should operate on.
    force.initialize = (simulationNodes) => {
        nodes = simulationNodes;
        return nodes;
    };

    // Getter/setter for the displacement multiplier (coerced to a number).
    force.strength = (...args) => {
        if (args.length === 0) {
            return strength;
        }
        strength = +args[0];
        return force;
    };

    // Getter/setter for the number of resolution passes per application.
    force.iterations = (...args) => {
        if (args.length === 0) {
            return iterations;
        }
        iterations = +args[0];
        return force;
    };

    return force;
};

/**
 * Builds a stopped d3-force simulation preloaded with the layout's base
 * forces: a many-body charge whose strength is -50 times the square root of
 * each node's area, x/y positioning forces targeting 0 with strength 0.05,
 * and the rectangle collision force at strength 0.8. The alpha target is set
 * to 0.05 and `stop()` is called before the simulation is returned.
 *
 * @returns {object} The configured simulation.
 */
const createSimulation = () => {
    const areaScaledCharge = (datum) => -50 * Math.sqrt(datum.width * datum.height);

    // Builders are invoked lazily, in order, right before each registration.
    const forceBuilders = [
        ['charge', () => forceManyBody().strength(areaScaledCharge)],
        ['x', () => forceX().x(0).strength(0.05)],
        ['y', () => forceY().y(0).strength(0.05)],
        ['collide', () => D3ForceCollide().strength(0.8)],
    ];

    const configured = forceBuilders.reduce(
        (simulation, [name, build]) => simulation.force(name, build()),
        forceSimulation()
    );

    return configured.alphaTarget(0.05).stop();
};

/**
 * React hook that exposes a d3-force based auto-layout for the current
 * React Flow graph.
 *
 * @returns {{ getLayoutedElementsD3: (iterations?: number) => void }}
 *   `getLayoutedElementsD3(iterations = 500)` runs the simulation
 *   synchronously for the given number of ticks, writes the resulting
 *   positions back with `setNodes` and schedules `fitView`.
 */
export const useLayoutedElements = () => {
    const { getNodes, setNodes, getEdges, fitView } = useReactFlow();

    const getLayoutedElements = useCallback((iterations = 500) => {
        const flowNodes = getNodes();

        // Simulate on lightweight stand-ins so that d3's bookkeeping fields
        // (x, y, vx, vy, index) and the size fallbacks below never end up on
        // the React Flow nodes themselves.
        const simNodes = flowNodes.map((node) => ({
            id: node.id,
            x: node.position.x,
            y: node.position.y,
            width: node.measured?.width || 340, // default width when not measured yet
            height: node.measured?.height || 150, // default height when not measured yet
        }));
        const knownIds = new Set(simNodes.map((simNode) => simNode.id));

        // Edge endpoints are node ids; an endpoint that is already an object
        // is reduced to its id as well.
        const endpointId = (endpoint) => (
            typeof endpoint === 'object' && endpoint !== null ? endpoint.id : endpoint
        );

        // forceLink mutates the links it is given (it replaces source/target
        // with node objects), so hand it fresh objects rather than the edges
        // from the store, and keep only links whose endpoints both exist.
        const links = getEdges()
            .map((edge) => ({
                source: endpointId(edge.source),
                target: endpointId(edge.target),
            }))
            .filter((link) => knownIds.has(link.source) && knownIds.has(link.target));

        const linkForce = forceLink(links)
            .id(d => d.id)
            .strength(0.2)
            // By the time distances are evaluated the link endpoints are the
            // simulation nodes; the rest length is the diagonal of the
            // averaged bounding box of both ends.
            .distance((link) => Math.hypot(
                (link.source.width + link.target.width) / 2,
                (link.source.height + link.target.height) / 2
            ));

        const simulation = createSimulation();
        simulation.nodes(simNodes)
            .force('link', linkForce)
            .force('collide', D3ForceCollide().strength(0.8));

        for (let i = 0; i < iterations; ++i) simulation.tick();

        setNodes(
            flowNodes.map((node, index) => {
                // Nodes whose parent is part of the graph are returned
                // untouched (presumably because their position is relative
                // to the parent — confirm against the React Flow version in use).
                if (node.parentId && knownIds.has(node.parentId)) {
                    return node;
                }

                const { x, y } = simNodes[index];
                return {
                    ...node,
                    position: { x, y },
                };
            })
        );

        // Deferred so the view is fitted after the node update is applied.
        setTimeout(fitView, 0);
    }, [getNodes, setNodes, getEdges, fitView]);

    return { getLayoutedElementsD3: getLayoutedElements };
};