_ = require 'lodash'
d3 = require 'd3-array'

V = require './v'

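# Disjoint-set (union-find) forest with union by rank and path compression.
# Works in place on arbitrary objects by attaching temporary `_parent` and
# `_rank` properties (removed again by unmakeSet).
# http://en.wikipedia.org/wiki/Disjoint-set_data_structure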
UnionFind = do ->
  # Every set is a tree of objects linked through `_parent`; a root is an
  # object whose `_parent` is itself. `_rank` is the union-by-rank tiebreaker
  # and only matters on roots.

  makeSet = (x) ->
    # Turn x into a singleton set: its own root, rank 0.
    x._parent = x
    x._rank = 0

  find = (x) ->
    # Return the root of x's set. Every node visited on the way is re-pointed
    # directly at that root (path compression).
    root = x
    root = root._parent until root._parent == root
    node = x
    until node == root
      next = node._parent
      node._parent = root
      node = next
    root

  union = (x, y) ->
    # Merge the sets containing x and y. Does nothing if they already share a root.
    rootX = find x
    rootY = find y
    return if rootX == rootY
    if rootX._rank < rootY._rank
      rootX._parent = rootY
    else if rootX._rank > rootY._rank
      rootY._parent = rootX
    else
      # Equal ranks: x's root absorbs y's root and its rank grows by one.
      rootY._parent = rootX
      rootX._rank += 1

  unmakeSet = (x) ->
    # Strip the bookkeeping properties that makeSet attached.
    delete x._parent
    delete x._rank

  {makeSet, find, union, unmakeSet}

mstEdges = (G, sortfn = d3.ascending) ->
  # Computes the edges of a minimum spanning tree over G = {V, E} using
  # Kruskal's algorithm. If G is disconnected, the result is a spanning forest.
  # Pass sortfn = d3.descending to compute a maximum spanning tree instead.
  #
  # - G.V: vertex objects. Each temporarily gets `_parent`/`_rank` properties
  #   (via UnionFind); they are removed before returning, even if an
  #   exception is thrown part-way.
  # - G.E: edges {u, v, weight} whose endpoints must be members of G.V.
  #   G.E itself is NOT reordered; a sorted copy is used.
  # Returns a new array holding the selected edge objects, in the order in
  # which they were accepted.
  UnionFind.makeSet v for v in G.V

  try
    # Sort a copy so the caller's edge list keeps its order.
    candidates = G.E.slice().sort (a, b) -> sortfn(a.weight, b.weight)
    edges = [ ]
    # Take edges best-first; keep one only if it joins two different
    # components, i.e. it does not close a cycle.
    for e in candidates when UnionFind.find(e.u) != UnionFind.find(e.v)
      edges.push e
      UnionFind.union e.u, e.v
    return edges
  finally
    # Always strip the temporary union-find properties from the vertices.
    UnionFind.unmakeSet v for v in G.V

###
# rememberDominantClusters
# ({u, v, u_count, v_count, ...}, leaders) -> leaders
# - u: one node in a breaking link
# - v: the other node in a breaking link
# - u_count: the number of nodes in the u cluster
# - v_count: the number of nodes in the v cluster
# - leaders: null, or the two-element list returned by a previous call.
#   Slots that have not been filled yet hold the placeholders
#   {relevance: -1} and {relevance: -2}.
#
# Stores Math.min(u_count, v_count) on the entry as `relevance`. It then
# updates the leaders list (in place, when one is passed in) so that it keeps
# the two entries with the highest relevance, best first. On ties, the
# earlier entry wins.
###
rememberDominantClusters = (entry, leaders) ->
  # A missing leaders list starts out as two placeholder slots that any real
  # entry (relevance >= 0) will outrank.
  leaders ?= [{relevance: -1}, {relevance: -2}]

  # Balance of the split: size of the smaller side.
  score = entry.relevance = Math.min(entry.u_count, entry.v_count)

  [best, runnerUp] = leaders
  # Not better than the current runner-up: nothing changes.
  return leaders unless score > runnerUp.relevance

  if score > best.relevance
    # New best; the old best drops to second place.
    leaders[0] = entry
    leaders[1] = best
  else
    leaders[1] = entry
  leaders
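###
  Outline of splitTree:
  - sort the edges by increasing weight
  - choose an arbitrary root
  - number the nodes with DFS pre-visit / post-visit indices
  - each edge then carries:
    - a weight
    - two endpoints
    - the index range (pre..post) of each endpoint
  - orient every edge consistently so that its second endpoint (v) lies
    inside the subtree of its first endpoint (u)
  - when a link is broken, split the remaining edge list in two (one list per
    resulting subtree). The inner endpoint v identifies the detached subtree:
    an edge goes to v's side when its u falls within v's index range.
    Then recurse on each side.
###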
splitTree = (G) ->
  # Linearize a tree G = {V, E} (edges {u, v, weight}) by repeatedly breaking
  # its weakest remaining link, and track the most balanced splits.
  # Assumes G is a spanning tree (adjacencyCluster passes mstEdges output);
  # every vertex must have `galaxyVectors[0]`.
  #
  # Returns [order, leaders]:
  # - order: all vertices. Both sides of every broken link are emitted as
  #   contiguous runs.
  # - leaders: rememberDominantClusters' result over the links broken by
  #   `rec`. Sublists of a single edge are emitted directly and are not
  #   recorded. If only one link is broken, leaders[1] is still the
  #   {relevance: -2} placeholder. An edgeless graph yields the same trivial
  #   root entry in both slots.
  #
  # Side effects: G.E is sorted in place and loses its weakest edge (shift).
  # Edge objects may have u/v swapped. Vertices carry `_preIndex`/`_postIndex`
  # only for the duration of the call.
  G.E.sort (a, b) -> d3.ascending(a.weight, b.weight)

  # Choose an arbitrary root
  root = G.V[0]
  unless G.E.length
    leader = {u: root, v: root, u_count: 1, v_count: 1, Eu: [], Ev: []}
    return [[root], [leader, leader]]

  # Number our nodes based on previsit and postvisit 'timestamps'. Node x lies
  # in the subtree of y iff y._preIndex <= x._preIndex and
  # x._postIndex <= y._postIndex.
  index = 0
  dfs G, {
    pre: (v) -> v._preIndex = index++
    post: (v) -> v._postIndex = index++
  }, root

  # Ensure that in the tree rooted at our root, u contains v
  for e in G.E
    if e.u._preIndex > e.v._preIndex
      [e.u, e.v] = [e.v, e.u]

  order = [ ]

  # Break the weakest link and partition the remaining edges into
  # two groups based on the resulting subtrees
  # In the meantime: track the dominant clusters
  leaders = null
  lastVec = null

  rec = (E) ->
    # E is sorted ascending by weight, so its head is the weakest link.
    weakest = E.shift()
    u = weakest.u
    v = weakest.v

    # Partition the remaining edges between the two subtrees. E may be empty
    # here: that happens when the whole tree is a single edge, and then both
    # sides are lone vertices. Without this, a one-edge tree produced an
    # empty order.
    Eu = [ ]
    Ev = [ ]

    for e in E
      if v._preIndex <= e.u._preIndex and e.u._postIndex <= v._postIndex
        # This is in the subtree branching out from v
        Ev.push e
      else
        Eu.push e

    # Remember dominant clusters (a tree side with k edges has k + 1 nodes)
    leaders = rememberDominantClusters {
      u: u
      v: v
      u_count: Eu.length + 1
      v_count: Ev.length + 1
      Eu: Eu
      Ev: Ev
    }, leaders

    swap = if lastVec?
      # Switch u and v depending on which was closer to the previous sublist
      dotU = V.dot(lastVec, u.galaxyVectors[0])
      dotV = V.dot(lastVec, v.galaxyVectors[0])
      dotV > dotU
    else
      # Switch u and v depending on the cluster sizes
      Eu.length > Ev.length

    if swap
      _tmp = v; v = u; u = _tmp
      _tmp = Ev; Ev = Eu; Eu = _tmp

    # Recurse on sublists with two or more edges; emit the rest directly
    if Eu.length
      if Eu.length == 1
        order.push Eu[0].u
        order.push Eu[0].v
      else
        rec Eu
    else
      order.push u
      lastVec = u.galaxyVectors[0]

    if Ev.length
      if Ev.length == 1
        order.push Ev[0].u
        order.push Ev[0].v
      else
        rec Ev
    else
      order.push v
      lastVec = v.galaxyVectors[0]

  rec G.E

  for v in G.V
    delete v._preIndex
    delete v._postIndex

  [order, leaders]


graphFromTerms = (terms) ->
  # Build a complete directed graph over `terms`. There is one edge
  # {u, v, weight} for every ordered pair, self-pairs included, weighted by
  # u.getAssociation(v). The `terms` array itself is used as V (not copied).
  edges = [ ]
  for source in terms
    edges.push({u: source, v: target, weight: source.getAssociation(target)}) for target in terms
  {V: terms, E: edges}


dfs = (G, visitFns, root = G.V[0]) ->
  # Depth-first traversal of the undirected graph G = {V, E} starting at
  # `root` (default: first vertex). Calls visitFns.pre(node) when a node is
  # first entered and visitFns.post(node) after all its neighbours are done.
  # Both callbacks are optional. Neighbours are explored in edge order.
  # Temporary `_visited`/`_adjacent` properties are removed at the end.
  for node in G.V
    node._visited = false
    node._adjacent = [ ]

  # Treat each edge as undirected; a self-loop is recorded only once.
  for edge in G.E
    a = edge.u
    b = edge.v
    a._adjacent.push b
    b._adjacent.push a if a != b

  explore = (node) ->
    node._visited = true
    visitFns.pre?(node)
    # The visited check runs per neighbour, so nodes reached deeper in the
    # recursion are skipped here.
    explore(next) for next in node._adjacent when not next._visited
    visitFns.post?(node)

  explore root

  for node in G.V
    delete node._visited
    delete node._adjacent

adjacencyCluster = (G) ->
  # Single-link agglomerative clustering: keep a maximum spanning tree of G
  # (strongest links first), then linearize it with splitTree.
  # Returns {order, leaders} exactly as splitTree produces them.
  tree =
    V: G.V
    E: mstEdges(G, d3.descending)
  [order, leaders] = splitTree tree
  {order, leaders}

### Map every vertex mentioned by a list of edges from its first exact term id
    to its first galaxy vector (each vertex appears once)
###
verticesFromEdges = (edges) ->
  # Later endpoints overwrite earlier ones that share a term id.
  byTermId = {}
  for edge in edges
    for endpoint in [edge.u, edge.v]
      byTermId[endpoint.exactTermIds[0]] = endpoint.galaxyVectors[0]
  byTermId

clusterMean = (edges, init) ->
  # Mean galaxy vector of one side of a split.
  # - edges: the cluster's edges. Every distinct endpoint (keyed by its first
  #   exact term id) contributes its first galaxy vector exactly once.
  # - init: the cluster's anchor vertex. Its vector sets the result length
  #   and is the whole answer when there are no edges.
  # Returns a lodash clone of init's vector for an edgeless cluster;
  # otherwise whatever V.scale returns for the averaged accumulator.
  seed = init.galaxyVectors[0]
  return _.clone(seed) if not edges.length

  vectors = (vec for termId, vec of verticesFromEdges(edges))
  sum = new Float32Array(seed.length)
  sum.fill(0)
  V.add(sum, vec, sum) for vec in vectors
  V.scale(sum, 1.0 / vectors.length, sum)

deciderFromDominantCluster = (leader) ->
  # Direction separating the two sides of a dominant split: the normalized
  # difference (mean of u's side) - (mean of v's side).
  # `leader` is an entry from splitTree's leaders ({u, v, Eu, Ev, ...}).
  # NOTE(review): the {relevance: -2} placeholder that splitTree can leave in
  # leaders[1] has no u/v/Eu/Ev, and passing it here throws; callers should
  # check for it.
  u_mean = clusterMean leader.Eu, leader.u
  v_mean = clusterMean leader.Ev, leader.v
  # Elsewhere in this file, V operations receive their destination as the
  # last argument (V.add mean, vec, mean; V.scale mean, s, mean). Pass u_mean
  # explicitly so the difference and the normalization are written into the
  # vector we return instead of possibly being discarded. u_mean is a fresh
  # vector from clusterMean, so updating it in place is safe.
  V.subtract u_mean, v_mean, u_mean
  V.normalize u_mean, u_mean
  return u_mean

# Public API: graph construction, clustering, and the split-direction helper.
module.exports = {graphFromTerms, adjacencyCluster, deciderFromDominantCluster}
