/**
 * Describes a tree structure using a nested set model.
 * See https://en.wikipedia.org/wiki/Nested_set_model
 */

import { UniqueIdentifier } from "@dnd-kit/core"
import { TreeNode } from "Types"
import {
  compact,
  drop,
  groupBy,
  isEqual,
  last,
  map,
  max,
  range,
  sortBy,
  take,
  uniqueId,
  zip,
} from "lodash"
import {
  useCallback,
  useEffect,
  useMemo,
  useReducer,
  useRef,
  useState,
} from "react"

/**
 * A tree in nested-set form: a flat array of nodes, each carrying `left`,
 * `right` and `depth`. Helpers in this file expect the array to be ordered by
 * ascending `left` (checkTree reports anything else as out of order).
 */
// ts-prune-ignore-next used in test
export type Tree = TreeNode[]

/**
 * Edits understood by the tree reducer. Each variant is dispatched to the
 * helper of the same name in this file (see ACTIONS).
 */
type Action =
  // Replace the whole tree, recomputing nested-set numbers from parent_id links.
  | {
      type: "load"
      nodes: TreeNode[]
    }
  // Insert subtree(s) under `parentId` (null = root level), ahead of the
  // sibling `before` (null = after the parent's existing children).
  | {
      type: "insert"
      parentId: UniqueIdentifier | null
      before: UniqueIdentifier | null
      nodes: TreeNode[]
    }
  // Delete a node together with its whole subtree.
  | {
      type: "remove"
      id: UniqueIdentifier
    }
  // Relocate a node and its subtree (a remove followed by an insert).
  | {
      type: "move"
      id: UniqueIdentifier
      parentId: UniqueIdentifier | null
      before: UniqueIdentifier | null
    }
  // Make a node the last child of its previous sibling.
  | {
      type: "indent"
      id: UniqueIdentifier
    }
  // Move a node one level up, next to its former parent.
  | {
      type: "outdent"
      id: UniqueIdentifier
    }
  // Change a node's label.
  | {
      type: "rename"
      id: UniqueIdentifier
      label: string
    }

/** The payload of action `A`: the matching Action variant without its `type` tag. */
type ActionParameters<A extends Action["type"]> = Omit<
  Extract<Action, { type: A }>,
  "type"
>

/**
 * Returns the node whose id equals `id`, or null when `id` is null/undefined
 * or no such node exists.
 *
 * Only null/undefined count as "no id". A falsy-but-valid id such as `0` is
 * still looked up (UniqueIdentifier allows numbers).
 */
const findNode = (tree: Tree, id: UniqueIdentifier | null) =>
  id == null ? null : tree.find((node) => node.id === id) ?? null

/**
 * Translates a pre-ordered run of nodes so that its first node lands on the
 * given `left` and `depth`. Every other node keeps its offsets relative to
 * that first node. An empty input is returned as-is (same reference).
 */
const renumber = (tree: Tree, left = 1, depth = 0) => {
  if (tree.length === 0) return tree

  const [anchor] = tree
  const leftShift = left - anchor.left
  const depthShift = depth - anchor.depth

  const shifted: TreeNode[] = []
  for (const node of tree) {
    shifted.push({
      ...node,
      left: node.left + leftShift,
      right: node.right + leftShift,
      depth: node.depth + depthShift,
    })
  }
  return shifted
}

/**
 * Builds a nested-set tree from plain `parent_id` links.
 *
 * Nodes are emitted in pre-order. Each node gets `left`/`right` numbers,
 * starting at `left`, and a `depth` that starts at `depth` for `roots`.
 * Children keep the relative order they have in `nodes`. A node whose
 * `parent_id` matches no visited node is not emitted.
 *
 * @param nodes - every node, without nested-set fields
 * @param roots - the top-level nodes to start from (defaults to nodes with a
 *   falsy `parent_id`)
 * @param left - the `left` value of the first root
 * @param depth - the depth assigned to `roots`
 * @returns the nodes in pre-order with `left`, `right` and `depth` filled in
 */
export const buildTree = (
  nodes: Omit<TreeNode, "left" | "right" | "depth">[],
  roots = nodes.filter((node) => !node.parent_id),
  left = 1,
  depth = 0
): Tree => {
  // Group children by parent once; the previous version regrouped every node
  // on each recursive call. Keys are stringified the same way lodash groupBy
  // keys an object, so `parent_id` 1 and "1" still collide as before.
  const childrenOf = new Map<string, typeof nodes>()
  for (const node of nodes) {
    const key = String(node.parent_id)
    const siblings = childrenOf.get(key)
    if (siblings) siblings.push(node)
    else childrenOf.set(key, [node])
  }

  const result: Tree = []

  // Appends `level` and all its descendants to `result` in pre-order, starting
  // at `start`. Returns the next unused nested-set number.
  const place = (level: typeof nodes, start: number, d: number): number => {
    let p = start
    for (const node of level) {
      const entry = { ...node, left: p, right: p + 1, depth: d }
      result.push(entry)
      // A leaf closes at p + 1; otherwise right is one past its last descendant.
      entry.right = place(childrenOf.get(String(node.id)) ?? [], p + 1, d + 1)
      p = entry.right + 1
    }
    return p
  }

  place(roots, left, depth)
  return result
}

/**
 * Inserts one or more subtrees into `tree` and returns the new tree. `tree`
 * itself is not mutated.
 *
 * `nodes` is a pre-ordered run of nested-set nodes, such as a subtree
 * collected before `remove` or a single freshly created node. Their `left`,
 * `right` and `depth` only matter relative to each other; they are renumbered
 * on insertion. If `nodes` contains several top-level subtrees (later entries
 * at the first node's depth), they are inserted one after another, in order.
 *
 * @param nodes - the subtree(s) to insert
 * @param parentId - the new parent, or null to insert at root level
 * @param before - the sibling the subtree should precede. When it is null,
 *   not in the tree, or not a child of `parentId`, the subtree is placed
 *   after the parent's existing children (or after the last root).
 */
// ts-prune-ignore-next used in test
export const insert = (
  tree: Tree,
  { nodes, parentId, before: beforeId }: ActionParameters<"insert">
): Tree => {
  if (!nodes.length) return tree
  const oldParentId = nodes[0].parent_id
  const oldDepth = nodes[0].depth

  // Insert only the first top-level subtree now; the rest are handled by recursion.
  const nextRootIndex = nodes.findIndex(
    (node, index) => index > 0 && node.depth === oldDepth
  )
  const nodesToInsert =
    nextRootIndex >= 0 ? nodes.slice(0, nextRootIndex) : nodes

  const parent = findNode(tree, parentId)
  const before =
    beforeId == null
      ? null
      : tree.find((node) => node.id === beforeId) ?? null
  const width = nodesToInsert[0].right - nodesToInsert[0].left + 1

  // Roots may carry null or undefined as parent_id; treat both as "no parent".
  const left =
    before && (before.parent_id ?? null) === (parentId ?? null)
      ? before.left
      : parent
        ? parent.right
        : (last(tree.filter((n) => !n.parent_id))?.right ?? 0) + 1

  const inserted = renumber(
    nodesToInsert,
    left,
    parent ? parent.depth + 1 : 0
  ).map((node) => ({
    ...node,
    // Only the subtree root pointed at the old parent; re-point it.
    parent_id: node.parent_id === oldParentId ? parentId : node.parent_id,
  }))

  // Open a gap of `width` numbers at `left`. This also widens every ancestor
  // whose right edge lies at or past the insertion point.
  const displaced = tree.map((node) => ({
    ...node,
    left: node.left >= left ? node.left + width : node.left,
    right: node.right >= left ? node.right + width : node.right,
  }))

  // Take the array position from `left` so the result stays ordered by
  // `left`. Looking up `before` instead broke when it was missing
  // (findIndex -> -1, so slice(0, -1) dropped the new nodes before the last
  // element) or when it did not directly follow the parent's subtree.
  const firstAfter = tree.findIndex((node) => node.left >= left)
  const index = firstAfter < 0 ? tree.length : firstAfter

  const newNodes = [
    ...displaced.slice(0, index),
    ...inserted,
    ...displaced.slice(index),
  ]
  if (nextRootIndex >= 0) {
    return insert(newNodes, {
      nodes: nodes.slice(nextRootIndex),
      parentId,
      before: beforeId,
    })
  } else {
    return newNodes
  }
}

/**
 * Deletes node `id` together with its whole subtree and closes the gap left
 * in the `left`/`right` numbering. If no node has that id, `tree` is returned
 * unchanged.
 */
// ts-prune-ignore-next used in test
export const remove = (tree: Tree, { id }: ActionParameters<"remove">) => {
  const target = findNode(tree, id)
  if (target === null) return tree

  const { left: from, right: to } = target
  const gap = to - from + 1

  const survivors: TreeNode[] = []
  for (const node of tree) {
    // Skip the removed node and everything nested inside it.
    if (node.left >= from && node.right <= to) continue
    // Nodes further right and ancestors' right edges slide left by `gap`.
    survivors.push({
      ...node,
      left: node.left > from ? node.left - gap : node.left,
      right: node.right > to ? node.right - gap : node.right,
    })
  }
  return survivors
}

/**
 * Moves node `id` and its subtree under `parentId` (null = root level), ahead
 * of sibling `before`. This is done as `remove` followed by `insert`.
 *
 * The tree is returned unchanged when the node does not exist, or when
 * `parentId` is the node itself or one of its descendants. In that case the
 * target parent would vanish with the removed subtree, and re-inserting would
 * leave the subtree's parent_id pointing into itself (a cycle).
 */
// ts-prune-ignore-next used in test
export const move = (
  tree: Tree,
  { id, parentId, before }: ActionParameters<"move">
) => {
  const node = findNode(tree, id)
  if (!node) return tree

  // The node plus all its descendants, in tree order.
  const nodes = tree.filter((n) => n.left >= node.left && n.right <= node.right)

  if (parentId != null && nodes.some((n) => n.id === parentId)) return tree

  return insert(remove(tree, { id }), {
    nodes,
    parentId,
    before,
  })
}

/** The deepest level a node may take when `previous` is the row directly above it (0 when there is none). */
const maxDepth = (previous: TreeNode | null) => {
  if (previous == null) return 0
  return previous.depth + 1
}

/**
 * Makes node `id` (with its subtree) the last child of its previous sibling.
 *
 * `tree` is returned unchanged when the node is missing, already as deep as
 * the row above allows (e.g. it is a first child), or has no previous sibling.
 * The array order does not change. The subtree shifts one number left and one
 * level down, and the new parent's right edge grows to enclose it.
 */
// ts-prune-ignore-next used in test
export const indent = (tree: Tree, { id }: ActionParameters<"indent">) => {
  const position = tree.findIndex((n) => n.id === id)
  if (position === -1) return tree

  const target = tree[position]
  const above = tree[position - 1]
  // A node may sit at most one level below the row directly above it.
  const limit = above ? above.depth + 1 : 0
  if (target.depth >= limit) return tree

  // Walk upwards to the nearest node at the same depth: the previous sibling.
  let parentIndex = position - 1
  while (parentIndex >= 0 && tree[parentIndex].depth !== target.depth) {
    parentIndex--
  }
  if (parentIndex < 0) return tree
  const newParent = tree[parentIndex]

  return tree.map((n) => {
    if (n.id === newParent.id) return { ...n, right: target.right }
    const inSubtree = n.left >= target.left && n.right <= target.right
    if (!inSubtree) return n
    return {
      ...n,
      parent_id: n.id === target.id ? newParent.id : n.parent_id,
      depth: n.depth + 1,
      left: n.left - 1,
      right: n.right - 1,
    }
  })
}

/**
 * Moves node `id` (with its subtree) one level up. It becomes a child of its
 * former grandparent (or a root when it was at depth 1) and is placed ahead
 * of the first later row at its old parent's depth or shallower. Siblings
 * that followed the node stay with the old parent.
 *
 * Returns `tree` unchanged for a missing node or a node already at depth 0.
 */
// ts-prune-ignore-next used in test
export const outdent = (tree: Tree, { id }: ActionParameters<"outdent">) => {
  const position = tree.findIndex((n) => n.id === id)
  if (position === -1) return tree

  const target = tree[position]
  if (target.depth <= 0) return tree

  // Nearest earlier node two levels up: the current grandparent.
  let newParentId: UniqueIdentifier | null = null
  for (let i = position - 1; i >= 0; i--) {
    if (tree[i].depth === target.depth - 2) {
      newParentId = tree[i].id
      break
    }
  }

  // First later node at the old parent's depth or shallower.
  let beforeId: UniqueIdentifier | null = null
  for (let i = position + 1; i < tree.length; i++) {
    if (tree[i].depth <= target.depth - 1) {
      beforeId = tree[i].id
      break
    }
  }

  return move(tree, { id, parentId: newParentId, before: beforeId })
}

/** Returns a copy of `tree` in which node `id` carries `label`; other nodes are reused as-is. */
const rename = (tree: Tree, { id, label }: ActionParameters<"rename">) =>
  tree.map((node) => {
    if (node.id !== id) return node
    return { ...node, label }
  })

/** A pure function that applies one action variant to a tree. */
type ActionReducer<A extends Action> = (state: Tree, action: A) => Tree

/**
 * Dispatch table from action type to the helper that applies it. The mapped
 * type ensures every Action variant has exactly one matching reducer.
 */
const ACTIONS: {
  [K in Action["type"]]: ActionReducer<Extract<Action, { type: K }>>
} = {
  // `load` ignores the current state and rebuilds from parent_id links.
  load: (_state, { nodes }) => buildTree(nodes),
  insert,
  remove,
  move,
  indent,
  outdent,
  rename,
}

/**
 * Lists the visible rows of `tree` in display order. Roots always appear.
 * A node's children appear right after it only when its id is in
 * `expandedIds`, applied recursively.
 *
 * Sibling order follows the order of `tree` (ascending `left`). A node whose
 * `parent_id` is null or undefined is treated as a root. Previously,
 * undefined-parent roots were bucketed under `undefined` and silently dropped.
 */
// ts-prune-ignore-next used in test
export const flatten = (tree: Tree, expandedIds: Set<UniqueIdentifier>) => {
  // Bucket children by parent in a single pass (push instead of re-spreading).
  const nodesByParentId = new Map<UniqueIdentifier | null, TreeNode[]>()
  for (const node of tree) {
    const key = node.parent_id ?? null
    const siblings = nodesByParentId.get(key)
    if (siblings) siblings.push(node)
    else nodesByParentId.set(key, [node])
  }

  const flattenNodes = (nodes: TreeNode[]): TreeNode[] =>
    nodes.flatMap((node) => [
      node,
      ...(expandedIds.has(node.id)
        ? flattenNodes(nodesByParentId.get(node.id) ?? [])
        : []),
    ])

  return flattenNodes(nodesByParentId.get(null) ?? [])
}

/** Options accepted by `useTree`. */
export type UseTreeOptions = {
  // NOTE(review): declared here but not destructured or read by useTree in
  // this file; confirm whether anything else consumes it.
  readOnly?: boolean
  // Seeds the hook's internal tree. A new reference is reloaded into state,
  // but only when `onChange` is also provided.
  tree: Tree
  // Ids that `isSelected` reports as selected; also drives `hasSelectedChildren`.
  correctNodeIds: Set<UniqueIdentifier>
  // Called with the new tree after an edit made through the hook.
  onChange?: (tree: Tree) => void
  // Called by `select` with the node matching the given id.
  onSelect?: (node: TreeNode) => void
  // Called by `deselect` with the node matching the given id.
  onDeselect?: (node: TreeNode) => void
}

/**
 * Debug check of the nested-set invariants. Returns a message describing the
 * first problem found, or undefined when the tree is consistent.
 *
 * Checks, in order:
 * 1. nodes are ordered by `left`;
 * 2. the `left`/`right` values are exactly 1..2n;
 * 3. every node tightly encloses its children, and consecutive siblings are
 *    adjacent (`next.left === previous.right + 1`).
 */
const checkTree = (tree: Tree): string | undefined => {
  const sorted = sortBy(tree, "left")
  if (!isEqual(sorted, tree)) return "Tree nodes are out of order"
  const edges = tree
    .flatMap((node) => [node.left, node.right])
    .sort((a, b) => a - b)
  if (!isEqual(edges, range(1, tree.length * 2 + 1)))
    return "Tree edges are invalid"

  // Index children by parent once (in tree order) rather than rescanning the
  // whole tree for every node, which made this check O(n^2).
  const childrenOf = new Map<unknown, TreeNode[]>()
  for (const node of tree) {
    const siblings = childrenOf.get(node.parent_id)
    if (siblings) siblings.push(node)
    else childrenOf.set(node.parent_id, [node])
  }

  const wellNested = (nodes: TreeNode[]): boolean =>
    nodes.every((node) => {
      const children = childrenOf.get(node.id) ?? []
      if (children.length === 0) return node.right === node.left + 1
      return (
        wellNested(children) &&
        node.left === children[0].left - 1 &&
        node.right === (last(children)?.right || 0) + 1
      )
    }) &&
    zip(take(nodes, nodes.length - 1), drop(nodes, 1)).every(
      ([previous, next]) => next!.left === previous!.right + 1
    )

  if (!wellNested(tree.filter((n) => !n.parent_id))) {
    return "Tree is not nested properly"
  }
  return undefined
}

/**
 * React hook that owns an editable nested-set tree, its UI state (expanded
 * rows and the row being edited) and the operations used to edit it.
 *
 * - `tree` seeds the internal state. Edits dispatched through the hook are
 *   reported via `onChange`.
 * - When `onChange` is provided and a different `tree` prop arrives, it is
 *   reloaded. Expanded ids are carried over by array position (current node
 *   i -> incoming node i). This is presumably so a node whose id changed
 *   between versions (e.g. a temporary `new_…` id) stays expanded. It assumes
 *   both versions list nodes in the same order.
 * - Every dispatched action and its resulting tree are recorded. After each
 *   one, checkTree validates the latest tree and any problem is logged with
 *   the full record.
 */
export const useTree = ({
  tree: initialTree,
  correctNodeIds,
  onChange,
  onSelect,
  onDeselect,
}: UseTreeOptions) => {
  // Roots start out expanded.
  const [expanded, setExpanded] = useState<Set<UniqueIdentifier>>(
    () => new Set(initialTree.filter((n) => !n.parent_id).map((n) => n.id))
  )

  // Whether the latest change came from a dispatched action ("internal") or
  // from a new `tree` prop ("external").
  const lastChangedBy = useRef<"external" | "internal">("external")

  // Debug log of [action, resulting tree] pairs.
  const [paperTrail, setPaperTrail] = useState<[Action, Tree][]>([])

  useEffect(() => {
    if (!paperTrail.length) return
    const t = last(paperTrail)![1]
    const error = checkTree(t)
    if (error) {
      console.error(error, JSON.stringify(paperTrail))
    }
  }, [paperTrail])

  const [tree, dispatch] = useReducer(
    <A extends Action>(state: Tree, action: A) => {
      lastChangedBy.current = "internal"
      // Apply the action once; the same result is recorded and returned.
      // (Previously it was computed twice, once for the log and once for the state.)
      const next = (ACTIONS[action.type] as ActionReducer<A>)(state, action)
      setPaperTrail((trail) => [...trail, [action, next]])
      return next
    },
    initialTree
  )

  // A new `tree` prop marks the next reconciliation as external.
  useEffect(() => {
    lastChangedBy.current = "external"
  }, [initialTree])

  // NOTE(review): `onChange` is not in the dependency list. Adding it would
  // re-run this effect whenever a caller passes a new inline callback.
  useEffect(() => {
    if (onChange && tree !== initialTree) {
      if (lastChangedBy.current === "internal") {
        onChange(tree)
      } else {
        // lodash zip pads the shorter array with undefined, so both sides are
        // checked. An incoming tree shorter than the current one used to
        // crash here.
        const map = new Map<UniqueIdentifier, UniqueIdentifier>(
          compact(
            zip(tree, initialTree).map(([oldNode, newNode]) =>
              oldNode && newNode ? [oldNode.id, newNode.id] : null
            )
          )
        )
        setExpanded(
          (expanded) => new Set([...expanded].map((id) => map.get(id) || id))
        )
        dispatch({ type: "load", nodes: initialTree })
      }
    }
  }, [tree, initialTree])

  const isSelected = useCallback(
    (node: TreeNode) => correctNodeIds.has(node.id),
    [correctNodeIds]
  )

  const flattened = useMemo(() => flatten(tree, expanded), [tree, expanded])

  const insert = useCallback(
    (action: ActionParameters<"insert">) =>
      dispatch({ type: "insert", ...action }),
    [dispatch]
  )

  const remove = useCallback(
    (action: ActionParameters<"remove">) =>
      dispatch({ type: "remove", ...action }),
    [dispatch]
  )

  // Collapses the node and every descendant.
  const collapse = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      const node = findNode(tree, id)
      if (!node) return

      setExpanded((expanded) => {
        const newSet = new Set(expanded)
        tree.forEach((n) => {
          if (n.left >= node.left && n.right <= node.right) newSet.delete(n.id)
        })
        return newSet
      })
    },
    [tree]
  )

  // Expands the node and every ancestor, so the node itself becomes visible.
  const expand = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      const node = findNode(tree, id)
      if (!node) return
      setExpanded((expanded) =>
        tree
          .filter((n) => n.left <= node.left && n.right >= node.right)
          .reduce((set, n) => set.add(n.id), new Set(expanded))
      )
    },
    [tree]
  )

  const move = useCallback(
    (action: ActionParameters<"move">) => {
      dispatch({ type: "move", ...action })
      if (action.parentId) expand({ id: action.parentId })
    },
    [dispatch, expand]
  )

  const indent = useCallback(
    (action: ActionParameters<"indent">) =>
      dispatch({ type: "indent", ...action }),
    [dispatch]
  )

  const outdent = useCallback(
    (action: ActionParameters<"outdent">) =>
      dispatch({ type: "outdent", ...action }),
    [dispatch]
  )

  const rename = useCallback(
    (action: ActionParameters<"rename">) =>
      dispatch({ type: "rename", ...action }),
    [dispatch]
  )

  // Ids (as strings) of the node and all its ancestors, root first.
  const path = useCallback(
    (node: TreeNode) => {
      const ancestors = tree.filter(
        (n) => n.left <= node.left && n.right >= node.right
      )
      return ancestors.map((n) => n.id.toString())
    },
    [tree]
  )

  const select = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      if (!onSelect) return
      const node = findNode(tree, id)
      if (node) onSelect(node)
    },
    [onSelect, tree]
  )

  const deselect = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      if (!onDeselect) return
      const node = findNode(tree, id)
      if (node) onDeselect(node)
    },
    [onDeselect, tree]
  )

  const load = useCallback(
    (nodes: Tree) => {
      dispatch({ type: "load", nodes })
      setExpanded(new Set(nodes.filter((n) => !n.parent_id).map((n) => n.id)))
    },
    [dispatch]
  )

  // Id of the node currently being edited, if any.
  const [editing, setEditing] = useState<UniqueIdentifier | null>(null)

  const newNodeId = () => uniqueId("new_")

  // Adds an empty node right after `id`'s subtree, under the same parent, and
  // starts editing it.
  const addSiblingOf = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      const node = findNode(tree, id)
      if (!node) return

      const index = tree.findIndex((n) => n.id === node.id)
      const before =
        tree.slice(index + 1).find((n) => n.depth <= node.depth)?.id || null

      const sibling = {
        id: newNodeId(),
        label: "",
        parent_id: null,
        depth: 0,
        left: 1,
        right: 2,
      }
      setEditing(sibling.id)
      if (node.parent_id) expand({ id: node.parent_id })
      insert({
        nodes: [sibling],
        parentId: node.parent_id,
        before,
      })
    },
    [tree, insert, expand]
  )

  // Appends an empty root node and starts editing it.
  const addNewRoot = useCallback(() => {
    const newNode = {
      id: newNodeId(),
      label: "",
      parent_id: null,
      depth: 0,
      left: 1,
      right: 2,
    }
    setEditing(newNode.id)
    insert({ nodes: [newNode], parentId: null, before: null })
  }, [insert])

  // Appends an empty last child to `id`, expands it and starts editing the child.
  const addChildOf = useCallback(
    ({ id }: { id: UniqueIdentifier }) => {
      const parent = findNode(tree, id)
      if (!parent) return

      const index = tree.findIndex((n) => n.id === parent.id)
      const before =
        tree.slice(index + 1).find((n) => n.depth <= parent.depth)?.id || null

      const node = {
        id: newNodeId(),
        label: "",
        parent_id: id,
        depth: 0,
        left: 1,
        right: 2,
      }
      expand(parent)
      setEditing(node.id)
      insert({ nodes: [node], parentId: id, before })
    },
    [tree, insert, expand]
  )

  // Ids of selected nodes plus all of their ancestors.
  const nodesWithSelectedChildren = useMemo(() => {
    const nodesById = new Map<UniqueIdentifier, TreeNode>(
      tree.map((node) => [node.id, node])
    )
    const nodeAndAncestors = (node?: TreeNode): UniqueIdentifier[] =>
      node
        ? node.parent_id
          ? [...nodeAndAncestors(nodesById.get(node.parent_id)), node.id]
          : [node.id]
        : []
    return new Set(
      tree.flatMap((node) => (isSelected(node) ? nodeAndAncestors(node) : []))
    )
  }, [tree, isSelected])

  const hasSelectedChildren = useCallback(
    ({ id }: { id: UniqueIdentifier }) => nodesWithSelectedChildren.has(id),
    [nodesWithSelectedChildren]
  )

  const isExpanded = useCallback(
    ({ id }: { id: UniqueIdentifier }) => expanded.has(id),
    [expanded]
  )

  const expandAll = useCallback(() => {
    setExpanded(new Set(tree.map((n) => n.id)))
  }, [tree])

  const collapseAll = useCallback(() => {
    setExpanded(new Set())
  }, [])

  return {
    tree,
    flattened,
    insert,
    remove,
    collapse,
    expand,
    move,
    indent,
    outdent,
    rename,
    select,
    deselect,
    load,
    addSiblingOf,
    addChildOf,
    addNewRoot,
    editing,
    setEditing,
    isSelected,
    isExpanded,
    hasSelectedChildren,
    path,
    correctNodeIds,
    expandAll,
    collapseAll,
  }
}
