import { TreeNode } from "Types"
import { useSectionContext } from "UsabilityHub/contexts"
import { last, map, orderBy, range, sortBy, uniq } from "lodash"
import {
  MutableRefObject,
  useCallback,
  useMemo,
  useReducer,
  useRef,
} from "react"
import { CommonPath, useGroupedPaths } from "../usePaths"
import { MERGE_RESULTS } from "./Legend"

/**
 * Outcome drawn for a node or link in the path diagram. Each outcome
 * (success / failure / skipped) is split by whether it was reached
 * directly or indirectly.
 */
export type PathDiagramResult =
  | "direct_success"
  | "indirect_success"
  | "direct_failure"
  | "indirect_failure"
  | "direct_skipped"
  | "indirect_skipped"

/**
 * A node in the path diagram: one step shared by one or more participant
 * paths, plus the position the layout assigns to it.
 */
export type LaidOutNode = {
  // --- identity and tree structure ---
  /** Colon-joined ids of the path steps up to and including this one ("skip" marks a skipped step) */
  id: string
  /** Id of this node's parent in the diagram tree */
  parentId: string | null
  /** Tree node visited at this step, or null when the step was skipped */
  node: TreeNode | null
  /** Index of this step along the path (0 for a root) */
  depth: number
  children: LaidOutNode[]
  /** Outgoing links to the children, one per (child, result) pair */
  links: LaidOutLink[]

  // --- participants ---
  /** Number of participants whose path passes through this node */
  value: number
  result: PathDiagramResult

  // --- layout ---
  /** Whether this node's children are laid out */
  expanded: boolean
  /** Column offset: depth * HORIZONTAL_GAP */
  x: number
  /** Top edge; the node occupies y .. y + height */
  y: number
  /** NOTE(review): initialised to 0 and not updated by the layout code visible here — confirm usage */
  top: number
  height: number
}

/**
 * A stream of participants moving from a node to one of its children,
 * drawn as a curve of thickness `width`.
 */
export type LaidOutLink = {
  sourceId: string
  targetId: string
  /** Number of participants following this link */
  value: number
  /** Outcome of those participants' paths (links to the same child are split by result) */
  result: PathDiagramResult
  /** Thickness: value * VERTICAL_SCALE */
  width: number
  /**
   * Horizontal extent. x0 stays 0 and x1 is set to HORIZONTAL_GAP, so these are
   * presumably relative to the source node — confirm against the renderer.
   */
  x0: number
  x1: number
  /** Vertical centre of the link at the source (y0) and target (y1) ends */
  y0: number
  y1: number
}

/**
 * Reducer action adding `ids` to the expanded set. `deep` mirrors the flag
 * passed to `expand`; the reducer itself does not read it.
 */
type ExpandAction = {
  type: "expand"
  ids: string[]
  deep?: boolean
}

/** Reducer action removing `id` and every id prefixed with `${id}:` from the expanded set. */
type CollapseAction = {
  type: "collapse"
  id: string
}

/** Horizontal distance between columns; a node's x is `depth * HORIZONTAL_GAP`. */
const HORIZONTAL_GAP = 200
/** Vertical size contributed by each participant to node heights and link widths. */
const VERTICAL_SCALE = 4
/** Space between stacked nodes in a column (doubled between nodes with different parents). */
const VERTICAL_GAP = 12
/** Nodes are never shorter than this, however few participants pass through them. */
const MIN_NODE_HEIGHT = 24
/** Padding added twice to each diagram dimension (once per side). */
const MARGIN = 24

/**
 * Builds the colon-joined id for the first `length` steps of a path, writing
 * `"skip"` for skipped (null) steps. This is the same scheme `nestPath` uses
 * for node ids, so the result can name a diagram node before it is built.
 */
const pathPrefixId = (nodes: CommonPath["nodes"], length: number) =>
  nodes
    .slice(0, length)
    .map((n) => n?.id ?? "skip")
    .join(":")

/**
 * Lays out the tree-test path diagram for the current section.
 *
 * Merges the grouped participant paths into a tree of `LaidOutNode`s (via
 * `nestPath`), positions the expanded part of that tree (via `layOutNodes`),
 * and tracks which nodes are expanded. First-level nodes start expanded.
 *
 * @returns
 *   - `roots`: the laid-out top-level nodes
 *   - `expandedNodeIds`: ids of currently expanded nodes
 *   - `expand(id, deep?)`: expands a node; `deep` also expands every known descendant
 *   - `collapse(id)`: collapses a node and clears the expanded state of its descendants
 *   - `width` / `height`: overall diagram size including margins
 *   - `horizontalGap`, `margin`: layout constants for rendering
 */
export const useTreeTestPathDiagramLayout = () => {
  // Written while computing `roots` below and read when building the return
  // value in the same render.
  const width = useRef(0)
  const height = useRef(0)

  const groupedPaths = useGroupedPaths()

  const { section } = useSectionContext()

  // Ids of the nominated correct answers, stringified to compare with node ids
  const correctNodeIds = useMemo(
    () =>
      new Set((section.tree_test_attributes?.correct_nodes || []).map(String)),
    [section]
  )

  // Every node id that can appear in the diagram; used to find the
  // descendants of a node for a deep expand
  const allNodeIds = useMemo(
    () =>
      uniq(
        groupedPaths.flatMap(({ nodes }) =>
          nodes.map((_node, i) => pathPrefixId(nodes, i + 1))
        )
      ),
    [groupedPaths]
  )

  const [expandedNodeIds, dispatch] = useReducer(
    (state: Set<string>, action: ExpandAction | CollapseAction) => {
      switch (action.type) {
        case "expand":
          return new Set([...state, ...action.ids])
        case "collapse":
          // Drop the node itself and every descendant (ids prefixed with `${id}:`)
          return new Set(
            [...state].filter(
              (id) => id !== action.id && !id.startsWith(`${action.id}:`)
            )
          )
      }
    },
    groupedPaths,
    (paths: CommonPath[]) =>
      new Set(
        paths.flatMap((path) =>
          // Expand the first level by default. The shared id scheme names a
          // skipped step "skip", matching the ids of the nodes being expanded.
          range(1).map((i) => pathPrefixId(path.nodes, i + 1))
        )
      )
  )

  const expand = useCallback(
    (id: string, deep = false) => {
      // Holding down shift does a "deep" expand of the current node (all descendants)
      // These nodes are prefixed with the current node's id
      const ids = deep
        ? [id, ...allNodeIds.filter((x) => x.startsWith(`${id}:`))]
        : [id]
      dispatch({ type: "expand", ids, deep })
    },
    [allNodeIds]
  )

  const collapse = useCallback(
    (id: string) => dispatch({ type: "collapse", id }),
    []
  )

  const roots = useMemo(() => {
    const rootNodes: LaidOutNode[] = []

    const nodeMap = new Map<string, LaidOutNode>()

    width.current = 0
    height.current = 0

    // Successes first, then failures, then everything else; within each,
    // by directness, then the most-travelled paths first
    for (const path of orderBy(
      groupedPaths,
      [
        ({ result }) =>
          result === "success" ? 0 : result === "failure" ? 1 : 2,
        "directness",
        "participants",
      ],
      ["asc", "asc", "desc"]
    )) {
      nestPath({
        path,
        at: rootNodes,
        nodeMap,
        correctNodeIds,
      })
    }

    for (const id of expandedNodeIds) {
      const node = nodeMap.get(id)
      if (node) node.expanded = true
    }

    // Busiest nodes first; this order is used for roots and children alike
    const sortedNodeIds: string[] = map(
      orderBy(nodeMap.values(), ["value"], ["desc"]),
      "id"
    )

    const sortedRoots = sortBy(rootNodes, ({ id }) => sortedNodeIds.indexOf(id))

    layOutNodes({
      nodes: sortedRoots,
      diagramWidth: width,
      diagramHeight: height,
      nodeMap,
      sortedNodeIds,
    })

    return sortedRoots
    // correctNodeIds decides each node's result, so it must invalidate the layout
  }, [groupedPaths, expandedNodeIds, correctNodeIds])

  return {
    roots,
    expandedNodeIds,
    expand,
    collapse,
    width: width.current + MARGIN * 2,
    height: height.current + MARGIN * 2,
    horizontalGap: HORIZONTAL_GAP,
    margin: MARGIN,
  }
}

/**
 * Converts a CommonPath into a nested subtree, merging it into the nodes and
 * links already built for earlier paths.
 *
 * A node's id is the colon-joined ids of the path steps up to and including
 * it, with `"skip"` for a skipped step, so paths sharing a prefix share nodes.
 * A node's `parentId` is its id minus the last step, or `null` for a root.
 *
 * @param options Object containing the following keys:
 *   - `path`: The path to convert
 *   - `depth`: How far down the path we are (defaults to 0, the root step)
 *   - `at`: Because paths may share nodes and links,
 *     the `at` option is used to specify the location to
 *     place the new node in, rather than returning it
 *   - `nodeMap`: a convenient way to look up all the nodes
 *     we've built so far; newly created nodes are added to it
 *   - `correctNodeIds`: Set of nominated correct answers (stringified ids)
 */
// ts-prune-ignore-next used in test
export const nestPath = ({
  path,
  depth = 0,
  at,
  nodeMap,
  correctNodeIds,
}: {
  path: CommonPath
  depth?: number
  at: LaidOutNode[]
  nodeMap: Map<string, LaidOutNode>
  correctNodeIds: Set<string>
}) => {
  const treeNode = path.nodes[depth]
  // `??` (not `||`) so a falsy-but-real id such as 0 isn't mistaken for a skip
  const nodeId = String(treeNode?.id ?? "skip")
  // Roots have no parent; everything else hangs off the id of the path prefix
  const parentId =
    depth > 0
      ? path.nodes
          .slice(0, depth)
          .map((n) => n?.id ?? "skip")
          .join(":")
      : null
  const id = parentId === null ? nodeId : `${parentId}:${nodeId}`
  let node = nodeMap.get(id)

  // Some path results are combined in this visualisation
  const result =
    MERGE_RESULTS[`${path.directness}_${path.result}` as PathDiagramResult]
  const correctNode = (correctNodeIds.has(nodeId) && treeNode) || null

  // Create a node if it doesn't exist already
  if (!node) {
    node = {
      id,
      parentId,
      node: treeNode,
      expanded: false,
      value: 0,
      children: [],
      links: [],
      x: 0,
      y: 0,
      depth,
      top: 0,
      height: 0,
      // A correct answer reached later on the path than its position in the
      // tree counts as an indirect success
      result: correctNode
        ? (`${
            correctNode.depth < depth ? "indirect" : "direct"
          }_success` as PathDiagramResult)
        : treeNode
          ? "indirect_failure"
          : "indirect_skipped",
    }
    at.push(node)
    nodeMap.set(id, node)
  }

  node.value += path.participants

  if (depth < path.nodes.length - 1) {
    // Recurse the remaining steps on the path
    nestPath({
      path,
      depth: depth + 1,
      at: node.children,
      nodeMap,
      correctNodeIds,
    })

    // Find or create a link to the next node (links are split by result)
    const next = path.nodes[depth + 1]
    const nextId = `${id}:${next?.id ?? "skip"}`
    let link = node.links.find(
      (l) => l.targetId === nextId && l.result === result
    )
    if (!link) {
      link = {
        sourceId: id,
        targetId: nextId,
        value: 0,
        result,
        x0: 0,
        x1: 0,
        y0: 0,
        y1: 0,
        width: 0,
      }
      node.links.push(link)
    }
    link.value += path.participants
  }
}

/**
 * Vertical offset of the top of the band of links attached to `node`. The
 * links take up `value * VERTICAL_SCALE` of the node's height and are centred
 * within it (the node can be taller because of `MIN_NODE_HEIGHT`).
 */
const linkBandTop = (node: LaidOutNode) =>
  node.y + (node.height - node.value * VERTICAL_SCALE) / 2

/**
 * Places nodes in their correct 2D layout, recursing into the children of
 * expanded nodes. Mutates the nodes, their links and the size refs in place.
 *
 * Within a column, each node goes below the previous one (with a doubled gap
 * between nodes of different parents) and never above its parent. Once its
 * children are laid out, a parent is moved down to line up with its first child.
 *
 * @param options Object containing the following keys:
 *   - nodes: List of `LaidOutNode`s to format
 *   - depth: Depth of the tree we’re at
 *   - diagramWidth: Ref to the maximum width of the diagram
 *   - diagramHeight: Ref to the maximum height of the diagram
 *   - columns: Used to keep track of vertical positions of nodes in each column
 *   - endpoints: Used to keep track of the vertical endpoints of links going into a node
 *   - nodeMap: Map for looking up nodes by id
 *   - sortedNodeIds: Preferred order for displaying nodes
 */
const layOutNodes = ({
  nodes,
  depth = 0,
  diagramWidth,
  diagramHeight,
  columns = [],
  endpoints = {},
  nodeMap,
  sortedNodeIds,
}: {
  nodes: LaidOutNode[]
  depth?: number
  diagramWidth: MutableRefObject<number>
  diagramHeight: MutableRefObject<number>
  columns?: LaidOutNode[][]
  endpoints?: Record<string, number>
  nodeMap: Map<string, LaidOutNode>
  sortedNodeIds: string[]
}) => {
  // Adjust width of the diagram to accommodate new nodes
  diagramWidth.current = Math.max(
    diagramWidth.current,
    (depth + 1) * HORIZONTAL_GAP
  )

  // Ensure we have enough columns in the layout
  while (columns.length <= depth) columns.push([])

  for (const node of nodes) {
    const parent = (node.parentId && nodeMap.get(node.parentId)) || null

    // Base the position on the previous node in the column, if any
    const previous = last(columns[depth])

    let y = previous
      ? previous.y +
        previous.height +
        VERTICAL_GAP * (previous.parentId === node.parentId ? 1 : 2)
      : 0

    // Ensure a node doesn't appear higher than its parent
    // This is not quite optimal use of space but it's more legible
    if (parent) y = Math.max(y, parent.y)

    node.height = Math.max(MIN_NODE_HEIGHT, node.value * VERTICAL_SCALE)
    node.x = node.depth * HORIZONTAL_GAP
    node.y = y
    // Incoming links stack downwards from the top of this node's link band
    endpoints[node.id] = linkBandTop(node)

    node.children = sortBy(node.children, ({ id }) => sortedNodeIds.indexOf(id))

    // Only recurse when there is something to lay out: recursing into an
    // expanded leaf would widen the diagram by an empty column.
    if (node.expanded && node.children.length > 0) {
      layOutNodes({
        nodes: node.children,
        depth: depth + 1,
        diagramWidth,
        diagramHeight,
        columns,
        endpoints,
        nodeMap,
        sortedNodeIds,
      })

      // Line the parent up with its first child, which may have been pushed down
      node.y = Math.max(node.y, node.children[0].y)
      endpoints[node.id] = linkBandTop(node)

      // Stack outgoing links in the same top-to-bottom order as their targets
      node.links = sortBy(
        node.links,
        ({ targetId }) => nodeMap.get(targetId)?.y ?? 0
      )

      let linkY = linkBandTop(node)

      // Links appear as stacked Bézier curves between two nodes.
      // We need to account for their thickness at both ends.
      for (const link of node.links) {
        const width = link.value * VERTICAL_SCALE
        link.x1 = HORIZONTAL_GAP
        link.y0 = linkY + width / 2
        link.y1 = endpoints[link.targetId] + width / 2
        link.width = width
        linkY += width
        endpoints[link.targetId] += width
      }
    }

    diagramHeight.current = Math.max(
      diagramHeight.current,
      node.y + node.height
    )

    columns[depth].push(node)
  }
}
