import { Button, ButtonGroup, Stack } from "@chakra-ui/react"
import {
  DndContext,
  DragMoveEvent,
  DragOverEvent,
  DragOverlay,
  DragStartEvent,
  DropAnimation,
  MeasuringStrategy,
  UniqueIdentifier,
  closestCenter,
  defaultDropAnimation,
} from "@dnd-kit/core"
import {
  SortableContext,
  arrayMove,
  verticalListSortingStrategy,
} from "@dnd-kit/sortable"
import { CSS } from "@dnd-kit/utilities"
import { AddIcon } from "Icons/AddIcon"
import { TreeNode } from "Types"
import { clamp, map } from "lodash"
import React, { CSSProperties, useMemo, useState } from "react"
import { createPortal } from "react-dom"
import { Item } from "./Item"
import { useTreeContext } from "./TreeProvider"

/** Props accepted by the `Tree` component. */
type TreeProps = {
  /**
   * Forwarded to each `Item` rendered in the list; when true the "Add node"
   * button is not rendered. `Tree` defaults it to `false`.
   */
  readOnly?: boolean
}

// Width of one depth level, in pixels. Exposed to the rendered tree as the
// `--indent` CSS variable and passed to `getProjection` as the indentation
// width used to turn horizontal drag distance into depth steps.
const indent = 62

// `measuring` configuration handed to `DndContext`: droppables use
// `MeasuringStrategy.Always`.
// NOTE(review): presumably required because the dragged node is collapsed on
// drag start, which changes the list layout mid-drag — confirm.
const measuring = {
  droppable: {
    strategy: MeasuringStrategy.Always,
  },
}

/**
 * Drop animation used by the `DragOverlay`.
 *
 * - `keyframes`: the overlay goes from fully opaque at its initial transform to
 *   fully transparent at its final transform shifted by 5 on both axes.
 * - `sideEffects`: the active node is faded in (opacity 0 -> 1) using the
 *   duration and easing of dnd-kit's `defaultDropAnimation`.
 */
const dropAnimationConfig: DropAnimation = {
  easing: "ease-out",
  keyframes: ({ transform: { initial: from, final: to } }) => {
    const firstFrame = {
      opacity: 1,
      transform: CSS.Transform.toString(from),
    }
    // Land slightly past the final position while fading out.
    const shifted = { ...to, x: to.x + 5, y: to.y + 5 }
    const lastFrame = {
      opacity: 0,
      transform: CSS.Transform.toString(shifted),
    }

    return [firstFrame, lastFrame]
  },
  sideEffects: ({ active }) => {
    const { duration, easing } = defaultDropAnimation
    // Block body on purpose: nothing is returned from this callback.
    active.node.animate([{ opacity: 0 }, { opacity: 1 }], { duration, easing })
  },
}

/**
 * Drag-and-drop sortable view of the tree exposed by `useTreeContext`.
 *
 * The flattened node list is rendered as a vertical sortable list. While a
 * node is dragged, the horizontal drag distance is turned into a projected
 * depth/parent via `getProjection`; on drop the result is committed through
 * the context's `move`. Ending or cancelling a drag always clears the drag
 * state and restores the body cursor.
 *
 * @param readOnly - Forwarded to each list `Item`; when true the "Add node"
 *   button is not rendered. Defaults to `false`.
 */
export const Tree: React.FC<TreeProps> = ({ readOnly = false }) => {
  const { flattened, move, collapse, addNewRoot } = useTreeContext()

  // Drag-session state; all four values are cleared together by `resetState`.
  const [activeId, setActiveId] = useState<UniqueIdentifier | null>(null)
  const [activeNode, setActiveNode] = useState<TreeNode | null>(null)
  const [overId, setOverId] = useState<UniqueIdentifier | null>(null)
  const [offsetLeft, setOffsetLeft] = useState(0)

  // Projected drop position of the dragged node; `null` while nothing is
  // being dragged. Ids are compared against `null` explicitly because a
  // `UniqueIdentifier` may be a falsy value such as `0` or `""`.
  const projected = useMemo(
    () =>
      activeId !== null && overId !== null
        ? getProjection(flattened, activeId, overId, offsetLeft, indent)
        : null,
    [activeId, overId, flattened, offsetLeft]
  )

  const sortedIds = useMemo(() => map(flattened, "id"), [flattened])

  // Clears every piece of drag state and undoes the cursor override applied
  // in `handleDragStart`.
  const resetState = () => {
    setOverId(null)
    setActiveId(null)
    setActiveNode(null)
    setOffsetLeft(0)

    document.body.style.removeProperty("cursor")
  }

  const handleDragStart = ({ active }: DragStartEvent) => {
    setActiveId(active.id)
    // Until the pointer moves over another item, the node is "over" itself.
    setOverId(active.id)

    const draggedNode = flattened.find(({ id }) => id === active.id) ?? null
    setActiveNode(draggedNode)

    if (draggedNode) {
      collapse(draggedNode)
    }

    document.body.style.setProperty("cursor", "grabbing")
  }

  const handleDragMove = ({ delta }: DragMoveEvent) => {
    setOffsetLeft(delta.x)
  }

  const handleDragOver = ({ over }: DragOverEvent) => {
    setOverId(over?.id ?? null)
  }

  // Typed with `DragOverEvent` (already imported by this file); only `active`
  // and `over` are read from the event.
  const handleDragEnd = ({ active, over }: DragOverEvent) => {
    // `projected` below is the value captured by this render, so resetting
    // the state first does not affect the move that follows.
    resetState()

    if (projected && over) {
      const { parentId } = projected
      const activeIndex = flattened.findIndex((node) => node.id === active.id)
      const overIndex = flattened.findIndex((node) => node.id === over.id)
      // Id of the flattened entry that follows the drop slot: the entry after
      // `over` when dragging downwards or dropping in place, `over` itself
      // when dragging upwards, and `null` when there is no such entry.
      const before =
        flattened[overIndex + (activeIndex <= overIndex ? 1 : 0)]?.id ?? null
      move({
        id: active.id,
        parentId,
        before,
      })
    }
  }

  // A cancelled drag never reaches `handleDragEnd`, so the drag state and the
  // "grabbing" cursor have to be cleaned up here as well.
  const handleDragCancel = () => {
    resetState()
  }

  return (
    <Stack rounded={6} bg="bg.page">
      <DndContext
        collisionDetection={closestCenter}
        measuring={measuring}
        onDragStart={handleDragStart}
        onDragMove={handleDragMove}
        onDragOver={handleDragOver}
        onDragEnd={handleDragEnd}
        onDragCancel={handleDragCancel}
      >
        <Stack
          gap={2}
          p={6}
          rounded={6}
          style={{ "--indent": `${indent}px` } as CSSProperties}
          role="tree"
          overflowX="auto"
        >
          <SortableContext
            items={sortedIds}
            strategy={verticalListSortingStrategy}
          >
            {flattened.map((node) => (
              <Item
                key={node.id}
                node={node}
                readOnly={readOnly}
                depth={
                  (node.id === activeId ? projected?.depth : null) ?? undefined
                }
              />
            ))}
            {createPortal(
              <DragOverlay dropAnimation={dropAnimationConfig}>
                {activeNode ? <Item node={activeNode} clone /> : null}
              </DragOverlay>,
              document.body
            )}
          </SortableContext>
        </Stack>
      </DndContext>
      {!readOnly && (
        <ButtonGroup size="sm" p={6} pt={0}>
          <Button
            variant="solid"
            colorScheme="brand.primary"
            leftIcon={<AddIcon />}
            onClick={addNewRoot}
          >
            Add node
          </Button>
        </ButtonGroup>
      )}
    </Stack>
  )
}

/**
 * Projects where the dragged node would land if it were dropped right now.
 *
 * The active node is moved to the position of the node it is over, the
 * horizontal drag distance is converted into a number of depth steps, and the
 * resulting depth is clamped between the depth of the following node (or 0)
 * and one level deeper than the preceding node (or 0).
 *
 * @param nodes - Flattened tree nodes in display order.
 * @param activeId - Id of the node being dragged.
 * @param overId - Id of the node the dragged node is currently over.
 * @param dragOffset - Horizontal drag distance, in the same unit as
 *   `indentationWidth`.
 * @param indentationWidth - Width of one depth level.
 * @returns The projected `depth`, its allowed `minDepth`/`maxDepth`, and the
 *   `parentId` implied by that depth (`null` for the root level); or `null`
 *   when `activeId` or `overId` is not present in `nodes`.
 */
const getProjection = (
  nodes: TreeNode[],
  activeId: UniqueIdentifier,
  overId: UniqueIdentifier,
  dragOffset: number,
  indentationWidth: number
) => {
  const overNodeIndex = nodes.findIndex(({ id }) => id === overId)
  const activeNodeIndex = nodes.findIndex(({ id }) => id === activeId)
  const activeNode = nodes[activeNodeIndex]

  // Either id may be missing from `nodes`. Without this guard `activeNode`
  // would be undefined and reading its depth would throw.
  if (!activeNode || overNodeIndex === -1) return null

  const newNodes = arrayMove(nodes, activeNodeIndex, overNodeIndex)
  const previousNode = newNodes[overNodeIndex - 1]
  const nextNode = newNodes[overNodeIndex + 1]
  const dragDepth = Math.round(dragOffset / indentationWidth)
  const projectedDepth = activeNode.depth + dragDepth
  // Not shallower than the node that follows, not more than one level deeper
  // than the node that precedes.
  const minDepth = nextNode ? nextNode.depth : 0
  const maxDepth = previousNode ? previousNode.depth + 1 : 0
  const depth = clamp(projectedDepth, minDepth, maxDepth)

  const getParentId = () => {
    if (depth === 0 || !previousNode) return null
    // Same level as the preceding node: share its parent.
    if (depth === previousNode.depth) return previousNode.parent_id
    // One level deeper: the preceding node becomes the parent.
    if (depth > previousNode.depth) return previousNode.id

    // Shallower than the preceding node: reuse the parent of the closest
    // earlier node sitting at the projected depth.
    return (
      newNodes
        .slice(0, overNodeIndex)
        .reverse()
        .find((node) => node.depth === depth && node.id !== activeId)
        ?.parent_id ?? null
    )
  }

  return { depth, maxDepth, minDepth, parentId: getParentId() }
}
