#' @import grid
#' @import circlize
#' @import ComplexHeatmap
#' @importFrom igraph graph V E layout_in_circle layout_on_sphere layout_randomly layout_with_fr layout_with_kk simplify
#' @importFrom ggpubr ggscatter
#' @import grDevices
#' 
NULL

#' Create a network heatmap
#'
#' Creates a heatmap of the signaling network. Alternatively, the network
#' matrix can be accessed directly in the signaling slot of a domino object using
#' the [dom_signaling()] function.
#'
#' @param dom domino object with network built ([build_domino()])
#' @param clusts vector of clusters to be included. If NULL then all clusters are used. A single cluster is allowed.
#' @param min_thresh minimum signaling threshold for plotting. Defaults to -Inf for no threshold.
#' @param max_thresh maximum signaling threshold for plotting. Defaults to Inf for no threshold.
#' @param scale how to scale the values (after thresholding). Options are 'none', 'sqrt' for square root, or 'log' for log10.
#' @param normalize options to normalize the matrix. Normalization is done after thresholding and scaling. Accepted inputs are 'none' for no normalization, 'rec_norm' to normalize to the maximum value within each receptor cluster, or 'lig_norm' to normalize to the maximum value within each ligand cluster. 'rec_norm' is skipped when only one ligand cluster is present and 'lig_norm' when only one receptor cluster is present.
#' @param ... other parameters to pass to [ComplexHeatmap::Heatmap()]
#' @return A heatmap rendered to the active graphics device
#' @export signaling_heatmap
#' @examples
#' example(build_domino, echo = FALSE)
#' #basic usage
#' signaling_heatmap(pbmc_dom_built_tiny)
#' #scale
#' signaling_heatmap(pbmc_dom_built_tiny, scale = "sqrt")
#' #normalize
#' signaling_heatmap(pbmc_dom_built_tiny, normalize = "rec_norm")
#'
signaling_heatmap <- function(
    dom, clusts = NULL, min_thresh = -Inf, max_thresh = Inf, scale = "none",
    normalize = "none", ...) {
  if (!dom@misc[["build"]]) {
    stop("Please run domino_build prior to generate signaling network.")
  }
  if (!length(dom@clusters)) {
    stop("This domino object wasn't built with clusters so intercluster signaling cannot be generated.")
  }
  # Rows are receptor clusters ("R_<cluster>"), columns ligand clusters ("L_<cluster>").
  mat <- dom@signaling

  if (!is.null(clusts)) {
    # drop = FALSE keeps a matrix when a single cluster is requested;
    # otherwise the subset collapses to a scalar that Heatmap() cannot draw.
    mat <- mat[paste0("R_", clusts), paste0("L_", clusts), drop = FALSE]
  }
  mat[which(mat > max_thresh)] <- max_thresh
  mat[which(mat < min_thresh)] <- min_thresh
  if (scale == "sqrt") {
    mat <- sqrt(mat)
  } else if (scale == "log") {
    mat <- log10(mat)
  } else if (scale != "none") {
    stop("Do not recognize scale input")
  }
  # Guards mirror incoming_signaling_heatmap()/signaling_network(): skip
  # normalization along a dimension that has only one entry.
  if (normalize == "rec_norm") {
    if (ncol(mat) > 1) {
      mat <- do_norm(mat, "row")
    }
  } else if (normalize == "lig_norm") {
    if (nrow(mat) > 1) {
      mat <- do_norm(mat, "col")
    }
  } else if (normalize != "none") {
    stop("Do not recognize normalize input")
  }
  # Scaling (e.g. log10/sqrt of negative values) or normalization can
  # introduce NA/NaN; Heatmap() needs finite values there.
  if (any(is.na(mat))) {
    warning("Some values are NA, replacing with 0s.")
    mat[is.na(mat)] <- 0
  }

  Heatmap(
    mat,
    name = "collective\nsignaling",
    ...
  )
}

#' Create a cluster incoming signaling heatmap
#'
#' Creates a heatmap of a cluster incoming signaling matrix. Each cluster has a
#' list of ligands capable of activating its enriched transcription factors. The
#' function creates a heatmap of cluster average expression for all of those
#' ligands. A list of all cluster incoming signaling matrices can be found in
#' the cl_signaling_matrices slot of a domino object as an alternative to this
#' plotting function.
#'
#' @param dom Domino object with network built ([build_domino()])
#' @param rec_clust Which cluster to select as the receptor. Must match naming of clusters in the domino object.
#' @param clusts Vector of clusters to be included. If NULL then all clusters are used.
#' @param min_thresh Minimum signaling threshold for plotting. Defaults to -Inf for no threshold.
#' @param max_thresh Maximum signaling threshold for plotting. Defaults to Inf for no threshold.
#' @param scale How to scale the values (after thresholding). Options are 'none', 'sqrt' for square root, or 'log' for log10.
#' @param normalize Options to normalize the matrix. Accepted inputs are 'none' for no normalization, 'rec_norm' to normalize to the maximum value within each receptor cluster, or 'lig_norm' to normalize to the maximum value within each ligand cluster.
#' @param title Either a string to use as the title or a boolean describing whether to include a title. In order to pass the 'column_title' parameter to [ComplexHeatmap::Heatmap()] you must set title to FALSE.
#' @param ... Other parameters to pass to [ComplexHeatmap::Heatmap()]. Note that to use the 'column_title' parameter of [ComplexHeatmap::Heatmap()] you must set title = FALSE
#' @return a Heatmap rendered to the active graphics device, or NULL (with a message) when the cluster has no signaling under the build parameters
#' @export incoming_signaling_heatmap
#' @examples
#' example(build_domino, echo = FALSE)
#' #incoming signaling of the CD8  T cells
#' incoming_signaling_heatmap(pbmc_dom_built_tiny, "CD8_T_cell")
#'
incoming_signaling_heatmap <- function(
    dom, rec_clust, clusts = NULL, min_thresh = -Inf, max_thresh = Inf,
    scale = "none", normalize = "none", title = TRUE, ...) {
  if (!dom@misc[["build"]]) {
    stop("Please run domino_build prior to generate signaling network.")
  }
  if (!length(dom@clusters)) {
    stop("This domino object wasn't build with clusters so cluster specific expression is not possible.")
  }
  mat <- dom@cl_signaling_matrices[[rec_clust]]
  # An unknown cluster name yields NULL; fail with a clear message instead of
  # the cryptic "argument is of length zero" from dim(NULL)[1] == 0.
  if (is.null(mat)) {
    stop("No signaling matrix found for cluster ", rec_clust, " in this domino object.")
  }
  if (nrow(mat) == 0) {
    message("No signaling found for this cluster under build parameters.")
    return()
  }

  if (!is.null(clusts)) {
    mat <- mat[, paste0("L_", clusts), drop = FALSE]
  }
  mat[which(mat > max_thresh)] <- max_thresh
  mat[which(mat < min_thresh)] <- min_thresh
  if (scale == "sqrt") {
    mat <- sqrt(mat)
  } else if (scale == "log") {
    mat <- log10(mat)
  } else if (scale != "none") {
    stop("Do not recognize scale input")
  }
  # Skip normalization along a dimension with a single entry.
  if (normalize == "rec_norm") {
    if (ncol(mat) > 1) {
      mat <- do_norm(mat, "row")
    }
  } else if (normalize == "lig_norm") {
    if (nrow(mat) > 1) {
      mat <- do_norm(mat, "col")
    }
  } else if (normalize != "none") {
    stop("Do not recognize normalize input")
  }

  if (any(is.na(mat))) {
    warning("Some values are NA, replacing with 0s.")
    mat[is.na(mat)] <- 0
  }

  if (isTRUE(title)) {
    title <- paste0("Expression of ligands targeting cluster ", rec_clust)
  }
  # column_title is only passed when a title is wanted so that callers using
  # title = FALSE may supply their own column_title through `...`.
  if (isFALSE(title)) {
    Heatmap(
      mat,
      name = "expression",
      ...
    )
  } else {
    Heatmap(
      mat,
      name = "expression",
      column_title = title,
      ...
    )
  }
}

#' Create a cluster to cluster signaling network diagram
#'
#' Creates a network diagram of signaling between clusters. Nodes are clusters
#' and directed edges indicate signaling from one cluster to another. Edges are
#' colored based on the color scheme of the ligand expressing cluster
#'
#' @param dom a domino object with network built ([build_domino()])
#' @param cols named vector indicating the colors for clusters. Values are colors and names must match clusters in the domino object. If left as NULL then ggplot colors are generated for the clusters
#' @param edge_weight weight for determining thickness of edges on plot. Signaling values are multiplied by this value
#' @param clusts vector of clusters to be included in the network plot
#' @param showOutgoingSignalingClusts vector of clusters to plot the outgoing signaling from
#' @param showIncomingSignalingClusts vector of clusters to plot the incoming signaling on
#' @param min_thresh minimum signaling threshold. Values lower than the threshold will be set to the threshold. Defaults to -Inf for no threshold
#' @param max_thresh maximum signaling threshold for plotting. Values higher than the threshold will be set to the threshold. Defaults to Inf for no threshold
#' @param normalize options to normalize the signaling matrix. Accepted inputs are 'none' for no normalization, 'rec_norm' to normalize to the maximum value within each receptor cluster, or 'lig_norm' to normalize to the maximum value within each ligand cluster
#' @param scale how to scale the values (after thresholding). Options are 'none', 'sqrt' for square root, 'log' for log10 (of value + 1), or 'sq' for square
#' @param layout type of layout to use. Options are 'random', 'sphere', 'circle', 'fr' for Fruchterman-Reingold force directed layout, and 'kk' for Kamada Kawai force directed layout
#' @param scale_by how to size vertices. Options are 'lig_sig' for summed outgoing signaling, 'rec_sig' for summed incoming signaling, and 'none'. In the former two cases the values are scaled with asinh after summing all incoming or outgoing signaling
#' @param vert_scale integer used to scale size of vertices with or without variable scaling from scale_by.
#' @param plot_title text for the plot's title.
#' @param ... other parameters to be passed to plot when used with an igraph object.
#' @return An igraph plot rendered to the active graphics device, or NULL (with a warning) when no signaling is found
#' @export signaling_network
#' @examples
#' example(build_domino, echo = FALSE)
#' #basic usage
#' signaling_network(pbmc_dom_built_tiny, edge_weight = 2)
#' # scaling, thresholds, layouts, selecting clusters
#' signaling_network(
#'  pbmc_dom_built_tiny, showOutgoingSignalingClusts = "CD14_monocyte",
#'  scale = "none", normalize = "none", layout = "fr", scale_by = "none",
#'  vert_scale = 5, edge_weight = 2)
#'
signaling_network <- function(
    dom, cols = NULL, edge_weight = 0.3, clusts = NULL, showOutgoingSignalingClusts = NULL,
    showIncomingSignalingClusts = NULL, min_thresh = -Inf, max_thresh = Inf, normalize = "none", scale = "sq",
    layout = "circle", scale_by = "rec_sig", vert_scale = 3, plot_title = NULL, ...) {
  if (!length(dom@clusters)) {
    stop("This domino object was not built with clusters so there is no intercluster signaling.")
  }
  if (!dom@misc[["build"]]) {
    stop("Please build a signaling network with domino_build prior to plotting.")
  }
  # Get signaling matrix: rows are "R_<cluster>", columns are "L_<cluster>".
  mat <- dom@signaling

  if (any(is.na(mat))) {
    warning("Some values are NA, replacing with 0s.")
    mat[is.na(mat)] <- 0
  }

  if (!is.null(clusts)) {
    mat <- mat[paste0("R_", clusts), paste0("L_", clusts), drop = FALSE]
  }
  if (!is.null(showOutgoingSignalingClusts)) {
    mat <- mat[, paste0("L_", showOutgoingSignalingClusts), drop = FALSE]
  }
  if (!is.null(showIncomingSignalingClusts)) {
    mat <- mat[paste0("R_", showIncomingSignalingClusts), , drop = FALSE]
  }
  if (sum(mat > 0) == 0) {
    warning("No signaling found")
    return(NULL)
  }
  if (is.null(cols)) {
    cols <- ggplot_col_gen(length(levels(dom@clusters)))
    names(cols) <- levels(dom@clusters)
  }

  mat[which(mat > max_thresh)] <- max_thresh
  mat[which(mat < min_thresh)] <- min_thresh
  if (scale == "sqrt") {
    mat <- sqrt(mat)
  } else if (scale == "log") {
    mat <- log10(mat + 1)
  } else if (scale == "sq") {
    mat <- mat^2
  } else if (scale != "none") {
    stop("Do not recognize scale input")
  }
  if (normalize == "rec_norm") {
    if (ncol(mat) > 1) {
      mat <- do_norm(mat, "row")
    }
  } else if (normalize == "lig_norm") {
    if (nrow(mat) > 1) {
      mat <- do_norm(mat, "col")
    }
  } else if (normalize != "none") {
    stop("Do not recognize normalize input")
  }
  # Build the edge list: one edge per non-zero entry, from the ligand
  # (sending) cluster to the receptor (receiving) cluster. Prefixes are
  # removed with anchored patterns so cluster names that themselves contain
  # "L_" or "R_" (e.g. "T_CELL_CD4") are left intact.
  links <- c()
  weights <- c()
  for (rcl in rownames(mat)) {
    R <- sub("^R_", "", rcl)
    for (lcl in colnames(mat)) {
      val <- mat[rcl, lcl]
      # Scaling/normalization can produce NA/NaN; treat those as no edge
      # rather than failing inside the comparison.
      if (is.na(val) || val == 0) {
        next
      }
      L <- sub("^L_", "", lcl)
      links <- c(links, L, R)
      weights[paste0(L, "|", R)] <- val
    }
  }
  graph <- igraph::graph(links)
  vert_names <- names(igraph::V(graph))
  # Get vert colors and scale size if desired.
  igraph::V(graph)$label.dist <- 1.5
  igraph::V(graph)$label.color <- "black"
  v_cols <- cols[vert_names]
  if (scale_by == "lig_sig" && all(sub("^L_", "", colnames(mat)) %in% vert_names)) {
    vals <- asinh(colSums(mat))
    vals <- vals[paste0("L_", vert_names)]
    igraph::V(graph)$size <- vals * vert_scale
  } else if (scale_by == "rec_sig" && all(sub("^R_", "", rownames(mat)) %in% vert_names)) {
    vals <- asinh(rowSums(mat))
    vals <- vals[paste0("R_", vert_names)]
    igraph::V(graph)$size <- vals * vert_scale
  } else {
    igraph::V(graph)$size <- vert_scale
  }
  # Get vert angle for labeling circos plot
  if (layout == "circle") {
    v_angles <- seq(length(igraph::V(graph)))
    v_angles <- -2 * pi * (v_angles - 1) / length(v_angles)
    igraph::V(graph)$label.degree <- v_angles
  }
  names(v_cols) <- c()
  igraph::V(graph)$color <- v_cols
  # Get edge weights and colors; edge names are "<ligand cluster>|<receptor cluster>".
  weights <- weights[attr(igraph::E(graph), "vnames")]
  e_cols <- c()
  for (e in names(weights)) {
    lcl <- strsplit(e, "|", fixed = TRUE)[[1]][1]
    e_cols <- c(e_cols, cols[lcl])
  }
  names(weights) <- c()
  names(e_cols) <- c()
  igraph::E(graph)$width <- weights * edge_weight
  igraph::E(graph)$color <- e_cols
  igraph::E(graph)$arrow.size <- 0
  igraph::E(graph)$curved <- 0.5
  # Compute vertex positions.
  if (layout == "random") {
    l <- igraph::layout_randomly(graph)
  } else if (layout == "circle") {
    l <- igraph::layout_in_circle(graph)
  } else if (layout == "sphere") {
    l <- igraph::layout_on_sphere(graph)
  } else if (layout == "fr") {
    l <- igraph::layout_with_fr(graph)
  } else if (layout == "kk") {
    l <- igraph::layout_with_kk(graph)
  } else {
    # Previously fell through and failed with "object 'l' not found".
    stop("Do not recognize layout input")
  }
  plot(graph, layout = l, main = plot_title, ...)
}

#' Create a gene association network
#'
#' Create a gene association network for genes from a given cluster. The
#' selected cluster acts as the receptor for the gene association network, so
#' only ligands, receptors, and features associated with the receptor cluster
#' will be included in the plot.
#'
#' @param dom Domino object with network built ([build_domino()])
#' @param clust Receptor cluster to create the gene association network for. A vector of clusters may be provided; ligands of all of them are then included.
#' @param OutgoingSignalingClust Vector of clusters to plot the outgoing signaling from
#' @param class_cols Named vector of colors used to color classes of vertices. Values must be colors and names must be classes ('rec', 'lig', and 'feat' for receptors, ligands, and features.).
#' @param cols Named vector of colors for individual genes. Genes not included in this vector will be colored according to class_cols.
#' @param lig_scale FALSE or a numeric value to scale the size of ligand vertices based on the summed values of that ligand's row in the receptor cluster signaling matrices.
#' @param layout Type of layout to use. Options are 'grid', 'random', 'sphere', 'circle', 'fr' for Fruchterman-Reingold force directed layout, and 'kk' for Kamada Kawai force directed layout.
#' @param ... Other parameters to pass to plot() with an [igraph](https://r.igraph.org/) object. See [igraph](https://r.igraph.org/) manual for options.
#' @return An igraph plot rendered to the active graphics device. Invisibly returns a list with the igraph `graph` and the `layout` matrix used.
#' @export gene_network
#' @examples
#' #basic usage
#' example(build_domino, echo = FALSE)
#' gene_network(
#'  pbmc_dom_built_tiny, clust = "CD8_T_cell",
#'  OutgoingSignalingClust = "CD14_monocyte")
#'
gene_network <- function(dom, clust = NULL, OutgoingSignalingClust = NULL,
    class_cols = c(lig = "#FF685F",rec = "#47a7ff", feat = "#39C740"),
    cols = NULL, lig_scale = 1, layout = "grid", ...) {
  if (!dom@misc[["build"]]) {
    warning("Please build a signaling network with domino_build prior to plotting.")
  }
  if (!length(dom@clusters)) {
    warning("This domino object wasn't build with clusters. The global signaling network will be shown.")
    lig_scale <- FALSE
  }
  # Get connections between TF and recs for clusters
  if (length(dom@clusters)) {
    all_sums <- c()
    tfs <- c()
    cl_with_signaling <- c()
    for (cl in as.character(clust)) {
      # Check if signaling exists for target cluster (an unknown name gives NULL)
      mat <- dom@cl_signaling_matrices[[cl]]
      if (is.null(mat) || nrow(mat) == 0) {
        message("No signaling found for ", cl, " under build parameters.")
        next
      }
      all_sums <- c(all_sums, rowSums(mat))
      tfs <- c(tfs, dom@linkages$clust_tf[[cl]])
      cl_with_signaling <- c(cl_with_signaling, cl)
    }
    all_sums <- all_sums[!duplicated(names(all_sums))]
    # If no signaling for target clusters then don't do anything
    if (length(tfs) == 0) {
      message("No signaling found for provided clusters")
      return()
    }
  } else {
    tfs <- dom@linkages$clust_tf[["clust"]]
  }
  links <- c()
  all_recs <- c()
  all_tfs <- c()
  for (cl in as.character(clust)) {
    for (tf in tfs) {
      recs <- dom@linkages$clust_tf_rec[[cl]][[tf]]
      all_recs <- c(all_recs, recs)
      if (length(recs)) {
        all_tfs <- c(all_tfs, tf)
      }
      for (rec in recs) {
        links <- c(links, rec, tf)
      }
    }
  }
  all_recs <- unique(all_recs)
  all_tfs <- unique(all_tfs)
  # Recs to ligs: collect the ligands allowed for every receptor cluster.
  if (length(dom@clusters)) {
    allowed_ligs <- c()
    if (!is.null(OutgoingSignalingClust)) {
      # Prefixed column names live in their own variable; re-prefixing the
      # argument inside the loop produced "L_L_..." on the second cluster.
      lig_clust_cols <- paste0("L_", OutgoingSignalingClust)
      out_sums <- c()
      for (cl in cl_with_signaling) {
        # drop = FALSE keeps a matrix for any number of rows/sender clusters.
        sig <- dom@cl_signaling_matrices[[cl]][, lig_clust_cols, drop = FALSE]
        lig_totals <- rowSums(sig)
        # Remove any ligands with zeroes for all selected sender clusters.
        lig_totals <- lig_totals[lig_totals > 0]
        allowed_ligs <- union(allowed_ligs, names(lig_totals))
        out_sums <- c(out_sums, lig_totals)
      }
      all_sums <- out_sums[!duplicated(names(out_sums))]
    } else {
      for (cl in cl_with_signaling) {
        allowed_ligs <- union(allowed_ligs, rownames(dom@cl_signaling_matrices[[cl]]))
      }
    }
  } else {
    allowed_ligs <- rownames(dom@z_scores)
  }
  # Keep only ligands that are allowed for the selected clusters.
  all_ligs <- c()
  for (rec in all_recs) {
    ligs <- dom@linkages$rec_lig[[rec]]
    for (lig in ligs) {
      if (lig %in% allowed_ligs) {
        links <- c(links, lig, rec)
        all_ligs <- c(all_ligs, lig)
      }
    }
  }
  all_ligs <- unique(all_ligs)
  # Make the graph
  graph <- igraph::graph(links)
  graph <- igraph::simplify(graph, remove.multiple = TRUE, remove.loops = FALSE)
  v_cols <- rep("#BBBBBB", length(igraph::V(graph)))
  names(v_cols) <- names(igraph::V(graph))
  v_cols[all_tfs] <- class_cols["feat"]
  v_cols[all_recs] <- class_cols["rec"]
  v_cols[all_ligs] <- class_cols["lig"]
  if (!is.null(cols)) {
    v_cols[names(cols)] <- cols
  }
  names(v_cols) <- c()
  igraph::V(graph)$color <- v_cols
  v_size <- rep(10, length(igraph::V(graph)))
  names(v_size) <- names(igraph::V(graph))
  if (lig_scale) {
    all_sums <- all_sums[names(all_sums) %in% names(v_size)]
    v_size[names(all_sums)] <- 0.5 * all_sums * lig_scale
  }
  names(v_size) <- c()
  igraph::V(graph)$size <- v_size
  igraph::V(graph)$label.degree <- pi
  igraph::V(graph)$label.offset <- 2
  igraph::V(graph)$label.color <- "black"
  igraph::V(graph)$frame.color <- "black"
  igraph::E(graph)$width <- 0.5
  igraph::E(graph)$arrow.size <- 0
  igraph::E(graph)$color <- "black"
  if (layout == "grid") {
    # Three columns: ligands (left), receptors (middle), features (right).
    l <- matrix(0, ncol = 2, nrow = length(igraph::V(graph)))
    rownames(l) <- names(igraph::V(graph))
    l[all_ligs, 1] <- -0.75
    l[all_recs, 1] <- 0
    l[all_tfs, 1] <- 0.75
    l[all_ligs, 2] <- (seq_along(all_ligs) / mean(seq_along(all_ligs)) - 1) * 2
    l[all_recs, 2] <- (seq_along(all_recs) / mean(seq_along(all_recs)) - 1) * 2
    l[all_tfs, 2] <- (seq_along(all_tfs) / mean(seq_along(all_tfs)) - 1) * 2
    rownames(l) <- c()
  } else if (layout == "random") {
    l <- igraph::layout_randomly(graph)
  } else if (layout == "circle") {
    l <- igraph::layout_in_circle(graph)
  } else if (layout == "sphere") {
    l <- igraph::layout_on_sphere(graph)
  } else if (layout == "fr") {
    l <- igraph::layout_with_fr(graph)
  } else if (layout == "kk") {
    l <- igraph::layout_with_kk(graph)
  } else {
    stop("Do not recognize layout input")
  }
  # Title uses the caller's cluster names (no "L_" prefix), one string even
  # when several clusters are given.
  plot_title <- paste0(
    "Signaling ", paste(OutgoingSignalingClust, collapse = ", "),
    " to ", paste(clust, collapse = ", ")
  )
  plot(graph, layout = l, main = plot_title, ...)
  return(invisible(list(graph = graph, layout = l)))
}

#' Create a heatmap of features organized by cluster
#'
#' Creates a heatmap of transcription factor activation scores by cells grouped by cluster.
#'
#' @param dom Domino object with network built ([build_domino()])
#' @param bool Boolean indicating whether the heatmap should be continuous or boolean. If boolean then bool_thresh will be used to determine how to define activity as positive or negative.
#' @param bool_thresh Numeric indicating the threshold separating 'on' or 'off' for feature activity if making a boolean heatmap.
#' @param title Either a string to use as the title or a boolean describing whether to include a title. In order to pass the 'column_title' parameter to [ComplexHeatmap::Heatmap()] you must set title to FALSE.
#' @param norm Boolean indicating whether or not to normalize the transcription factors to their max value.
#' @param feats Either a vector of features to include in the heatmap or 'all' for all features. If left NULL then the features selected for the signaling network will be shown.
#' @param ann_cols Boolean indicating whether to include cell cluster as a column annotation. Colors can be defined with cols. If FALSE then custom annotations can be passed to [ComplexHeatmap::Heatmap()].
#' @param cols Named vector of colors to annotate cells by cluster color. Values are taken as colors and names as cluster. If left as NULL then default ggplot colors will be generated.
#' @param min_thresh Minimum threshold for color scaling if not a boolean heatmap
#' @param max_thresh Maximum threshold for color scaling if not a boolean heatmap
#' @param ... Other parameters to pass to [ComplexHeatmap::Heatmap()]. Note that to use the 'column_title' parameter of [ComplexHeatmap::Heatmap()] you must set title = FALSE and to use 'top_annotation' ann_cols must be FALSE.
#' @return A heatmap rendered to the active graphics device
#' @export feat_heatmap
#' @examples
#' #basic usage
#' example(build_domino, echo = FALSE)
#' feat_heatmap(pbmc_dom_built_tiny)
#' #using thresholds
#' feat_heatmap(
#'  pbmc_dom_built_tiny, min_thresh = 0.1,
#'  max_thresh = 0.6, norm = TRUE, bool = FALSE)
#'
feat_heatmap <- function(
    dom, feats = NULL, bool = FALSE, bool_thresh = 0.2, title = TRUE, norm = FALSE,
    cols = NULL, ann_cols = TRUE, min_thresh = NULL, max_thresh = NULL, ...) {
  if (!length(dom@clusters)) {
    warning("This domino object wasn't built with clusters. Cells will not be ordered.")
    ann_cols <- FALSE
  }
  # Rows are features, columns are cells.
  mat <- dom@features
  cl <- sort(dom@clusters)
  if (norm && (!is.null(min_thresh) || !is.null(max_thresh))) {
    warning("You are using norm with min_thresh and max_thresh. Note that values will be thresholded AFTER normalization.")
  }
  if (norm) {
    mat <- do_norm(mat, "row")
  }
  if (!is.null(min_thresh)) {
    mat[which(mat < min_thresh)] <- min_thresh
  }
  if (!is.null(max_thresh)) {
    mat[which(mat > max_thresh)] <- max_thresh
  }
  if (bool) {
    cp <- mat
    cp[which(mat >= bool_thresh)] <- 1
    cp[which(mat < bool_thresh)] <- 0
    mat <- cp
  }
  if (isTRUE(title)) {
    title <- "Feature expression by cluster"
  }
  if (is.null(feats)) {
    # Default: every feature linked to any cluster in the network.
    feats <- c()
    for (i in dom@linkages$clust_tf) {
      feats <- c(feats, i)
    }
    feats <- unique(feats)
  } else if (feats[1] != "all") {
    mid <- match(feats, rownames(dom@features))
    na <- which(is.na(mid))
    if (length(na) != 0) {
      message("Unable to find ", paste(feats[na], collapse = " "))
      feats <- feats[-na]
    }
  } else {
    feats <- rownames(mat)
  }
  # Always restrict to the selected features (previously ignored when the
  # object had no clusters); drop = FALSE keeps a single feature as a matrix.
  if (length(cl)) {
    mat <- mat[feats, names(cl), drop = FALSE]
  } else {
    mat <- mat[feats, , drop = FALSE]
  }
  if (ann_cols) {
    if (is.null(cols)) {
      cols <- ggplot_col_gen(length(levels(cl)))
      names(cols) <- levels(cl)
    }
    feat_anno <- columnAnnotation(
      Cluster = cl,
      col = list(Cluster = cols)
    )
  }
  # feat_anno only exists when ann_cols is set; column_title is only passed
  # when a title is wanted so callers can supply their own via `...`.
  if (ann_cols) {
    if (isFALSE(title)) {
      Heatmap(
        mat,
        name = "feature\nactivity",
        top_annotation = feat_anno,
        cluster_columns = FALSE, show_column_names = FALSE,
        ...
      )
    } else {
      Heatmap(
        mat,
        name = "feature\nactivity",
        top_annotation = feat_anno,
        cluster_columns = FALSE, show_column_names = FALSE,
        column_title = title,
        ...
      )
    }
  } else if (isFALSE(title)) {
    Heatmap(
      mat,
      name = "feature\nactivity",
      cluster_columns = FALSE, show_column_names = FALSE,
      ...
    )
  } else {
    Heatmap(
      mat,
      name = "feature\nactivity",
      cluster_columns = FALSE, show_column_names = FALSE,
      column_title = title,
      ...
    )
  }
}

#' Create a heatmap of correlation between receptors and transcription factors
#'
#' Creates a heatmap of correlation values between receptors and transcription
#' factors either with boolean threshold or with continuous values displayed
#'
#' @param dom Domino object with network built ([build_domino()])
#' @param bool Boolean indicating whether the heatmap should be continuous or boolean. If boolean then bool_thresh will be used to determine how to define activity as positive or negative.
#' @param bool_thresh Numeric indicating the threshold separating 'on' or 'off' for feature activity if making a boolean heatmap.
#' @param title Either a string to use as the title or a boolean describing whether to include a title. In order to pass the 'column_title' parameter to [ComplexHeatmap::Heatmap()] you must set title to FALSE.
#' @param feats Either a vector of features to include in the heatmap or 'all' for all features. If left NULL then the features selected for the signaling network will be shown.
#' @param recs Either a vector of receptors to include in the heatmap or 'all' for all receptors. If left NULL then the receptors selected in the signaling network connected to the features plotted will be shown.
#' @param mark_connections Boolean indicating whether to add an 'x' in cells where there is a connected receptor or TF. Default FALSE.
#' @param ... Other parameters to pass to [ComplexHeatmap::Heatmap()]. Note that to use the 'column_title' parameter of [ComplexHeatmap::Heatmap()] you must set title = FALSE, and to use 'cell_fun' mark_connections must be FALSE.
#' @return A heatmap rendered to the active graphics device
#' @export cor_heatmap
#' @examples
#' example(build_domino, echo = FALSE)
#' #basic usage
#' cor_heatmap(pbmc_dom_built_tiny, title = "PBMC R-TF Correlations")
#' #show correlations above a specific value
#' cor_heatmap(pbmc_dom_built_tiny, bool = TRUE, bool_thresh = 0.1)
#' #identify combinations that are connected
#' cor_heatmap(pbmc_dom_built_tiny, bool = FALSE, mark_connections = TRUE)
#'
cor_heatmap <- function(
    dom, bool = FALSE, bool_thresh = 0.15, title = TRUE, feats = NULL, recs = NULL,
    mark_connections = FALSE, ...) {
  # Correlation matrix indexed as [receptor, feature].
  mat <- dom@cor
  if (bool) {
    cp <- mat
    cp[which(mat >= bool_thresh)] <- 1
    cp[which(mat < bool_thresh)] <- 0
    mat <- cp
  }
  if (isTRUE(title)) {
    title <- "Correlation of features and receptors"
  }
  if (is.null(feats)) {
    # Default: every feature linked to any cluster in the network.
    feats <- unique(unlist(dom@linkages$clust_tf, use.names = FALSE))
  } else if (feats[1] != "all") {
    mid <- match(feats, rownames(dom@features))
    na <- which(is.na(mid))
    if (length(na) != 0) {
      message("Unable to find ", paste(feats[na], collapse = " "))
      feats <- feats[-na]
    }
  } else {
    # Features are the columns of the correlation matrix (receptors are rows).
    feats <- colnames(mat)
  }
  if (is.null(recs)) {
    # Default: receptors linked to the selected features.
    recs <- unique(unlist(dom@linkages$tf_rec[feats], use.names = FALSE))
  } else if (identical(recs, "all")) {
    recs <- rownames(mat)
  }
  mat <- mat[recs, feats, drop = FALSE]
  if (mark_connections) {
    cons <- matrix("", nrow = nrow(mat), ncol = ncol(mat), dimnames = dimnames(mat))
    for (feat in feats) {
      # Only mark receptors that are actually shown in the heatmap.
      feat_recs <- intersect(dom@linkages$tf_rec[[feat]], rownames(cons))
      if (length(feat_recs)) {
        cons[feat_recs, feat] <- "X"
      }
    }
    mark_fun <- function(j, i, x, y, w, h, col) {
      grid.text(
        cons[i, j], x, y,
        gp = gpar(col = "#000000")
      )
    }
  }
  # Title and connection marks are independent options; each argument is only
  # passed when requested so callers can supply their own via `...`.
  if (isFALSE(title)) {
    if (mark_connections) {
      Heatmap(mat, name = "rho", cell_fun = mark_fun, ...)
    } else {
      Heatmap(mat, name = "rho", ...)
    }
  } else {
    if (mark_connections) {
      Heatmap(mat, name = "rho", column_title = title, cell_fun = mark_fun, ...)
    } else {
      Heatmap(mat, name = "rho", column_title = title, ...)
    }
  }
}

#' Create a correlation plot between TF and receptor
#'
#' Plots per-cell transcription factor activation score against receptor
#' expression (z-score), with a fitted regression line.
#'
#' @param dom Domino object with network built ([build_domino()])
#' @param tf Target TF for plotting AUC score
#' @param rec Target receptor for plotting expression
#' @param remove_rec_dropout Whether to drop cells with zero receptor counts from the plot. This should match the same setting as in [build_domino()].
#' @param ... Other parameters to pass to [ggpubr::ggscatter()].
#' @return A ggplot scatter plot rendered in the active graphics device
#' @export cor_scatter
#' @examples
#' example(build_domino, echo = FALSE)
#' cor_scatter(pbmc_dom_built_tiny, "FLI1","CXCR3")
#'
cor_scatter <- function(dom, tf, rec, remove_rec_dropout = TRUE, ...) {
  if (!remove_rec_dropout) {
    rec_vals <- dom@z_scores[rec, ]
    tf_vals <- dom@features[tf, ]
  } else {
    # Restrict to cells with a non-zero raw count for the receptor.
    expressing <- which(dom@counts[rec, ] > 0)
    rec_vals <- dom@z_scores[rec, expressing]
    tf_vals <- dom@features[tf, expressing]
  }
  scatter_df <- data.frame(rec = rec_vals, tf = tf_vals)
  ggscatter(
    scatter_df,
    x = "rec", y = "tf",
    add = "reg.line", conf.int = FALSE,
    cor.coef = FALSE, cor.method = "pearson",
    xlab = rec, ylab = tf,
    size = 0.25,
    ...
  )
}

#' Plot expression of a receptor's ligands by other cell types as a chord plot
#'
#' Creates a chord plot of expression of ligands that can activate a specified
#' receptor where chord widths correspond to mean ligand expression by the cluster.
#'
#' @param dom Domino object that has undergone network building with [build_domino()]
#' @param receptor Name of a receptor active in at least one cell type in the domino object
#' @param ligand_expression_threshold Minimum mean expression value of a ligand by a cell type for a chord to be rendered between the cell type and the receptor
#' @param cell_idents Vector of cell types from cluster assignments in the domino object to be included in the plot. If NULL, all clusters are used in sorted order.
#' @param cell_colors Named vector of color names or hex codes, where names correspond to the plotted cell types and values are their colors
#' @return Renders a circos plot to the active graphics device
#' @export circos_ligand_receptor
#' @examples
#' example(build_domino, echo = FALSE)
#' #basic usage
#' circos_ligand_receptor(pbmc_dom_built_tiny, receptor = "CXCR3")
#' #specify colors
#' cols <- c("red", "orange", "green")
#' names(cols) <- dom_clusters(pbmc_dom_built_tiny)
#' circos_ligand_receptor(pbmc_dom_built_tiny, receptor = "CXCR3", cell_colors = cols)
#'
circos_ligand_receptor <- function(
    dom, receptor, ligand_expression_threshold = 0.01, cell_idents = NULL,
    cell_colors = NULL) {
  # Ligands linked to this receptor in the domino result.
  receptor_ligands <- dom@linkages$rec_lig[[receptor]]

  # Fall back to every cluster label in the object, sorted.
  idents <- if (is.null(cell_idents)) sort(unique(dom@clusters)) else cell_idents

  circos_data <- obtain_circos_expression(
    dom = dom,
    receptor = receptor,
    ligands = receptor_ligands,
    ligand_expression_threshold = ligand_expression_threshold,
    cell_idents = idents
  )

  render_circos_ligand_receptor(
    signaling_df = circos_data,
    receptor = receptor,
    cell_colors = cell_colors,
    ligand_expression_threshold = ligand_expression_threshold
  )
}


#' Obtain Circos Expression
#' 
#' Pull expression data from a domino object and format for plotting as a receptor-oriented circos plot.
#' Ligand expression values are read from the cluster signaling matrix of the
#' first cluster in which the receptor is active.
#' 
#' @param dom Domino object that has undergone network building with build_domino()
#' @param receptor Name of a receptor active in at least one cell type in the domino object
#' @param ligands Character vector of ligands capable of interaction with the receptor
#' @param ligand_expression_threshold Minimum mean expression value of a ligand by a cell type for a chord to be rendered between the cell type and the receptor
#' @param cell_idents Vector of cell types from cluster assignments in the domino object to be included in the plot. If NULL, all cell types in the signaling matrix are kept.
#' @return a data frame where each row describes plotting parameters of ligand-receptor interactions to pass to render_circos_ligand_receptor().
#' Columns are origin, destination, mean.expression, sender, ligand, receptor,
#' scaled.mean.expression, ligand.arc and receptor.arc.
#' Errors if the receptor is not active in any cluster, if none of the ligands are in the signaling matrix,
#' if no rows remain after filtering by cell_idents, or if no ligand exceeds the expression threshold.
#' @export obtain_circos_expression
#' @examples 
#' example(build_domino, echo = FALSE)
#' #basic usage
#' obtain_circos_expression(pbmc_dom_built_tiny, receptor = "CXCR3", ligands = "CCL20")
#' 

obtain_circos_expression <- function(dom, receptor, ligands, ligand_expression_threshold = 0.01, cell_idents = NULL){
  # which clusters list the receptor as active
  active_chk <- vapply(
    dom@linkages$clust_rec,
    FUN.VALUE = logical(1), FUN = function(x) {receptor %in% x}
  )
  if (!any(active_chk)) {
    stop("No clusters have active ", receptor, " signaling")
  }
  # obtain a signaling matrix where receptor is active (the first such cluster)
  active_cell <- names(active_chk)[active_chk]
  sig <- dom@cl_signaling_matrices[active_cell][[1]]
  # sender columns are named "L_<cluster>"; strip the prefix to get cluster names
  cell_names <- gsub("^L_", "", colnames(sig))

  # ensure only ligands present in the signaling matrix are considered
  lig_check <- ligands %in% rownames(sig)
  if (!all(lig_check)) {
    message(paste0(
      "Ligands: ", paste(ligands[!lig_check], collapse = ","),
      " of receptor ", receptor, " are listed in the rl_map, but not present in the signaling matrix."
    ))
  }
  # also catches an empty/NULL `ligands`, which previously slipped through
  if (!any(lig_check)) {
    stop(paste0("No ligands of receptor ", receptor, " present in signaling matrix."))
  }
  if (!all(lig_check)) {
    message(paste0("Only ligands: ", paste(ligands[lig_check], collapse = ","), " will be considered."))
  }
  ligands <- ligands[lig_check]

  # one row per (sender cluster, ligand) pair
  lig_signal_ls <- lapply(ligands, function(l) {
    data.frame(
      origin = paste0(cell_names, "-", l),
      destination = receptor,
      mean.expression = unname(sig[rownames(sig) == l, ]),
      sender = cell_names,
      ligand = l,
      receptor = receptor
    )
  })
  signaling_df <- do.call(rbind, lig_signal_ls)

  if (!is.null(cell_idents)) {
    signaling_df <- signaling_df[signaling_df$sender %in% cell_idents, ]
  }
  if (nrow(signaling_df) == 0) {
    stop("None of the specified cell_idents are present in the signaling matrix of receptor ", receptor, ".")
  }

  signaling_df$mean.expression[is.na(signaling_df$mean.expression)] <- 0
  # exit function if no ligands are expressed above ligand expression threshold
  if (!any(signaling_df[["mean.expression"]] > ligand_expression_threshold)) {
    stop("No ligands of ", receptor, " exceed ligand expression threshold.")
  }
  # scaled mean expression (divided by the max) is used for chord widths when the
  # max expression is > 1; guard against division by zero
  max_expr <- max(signaling_df$mean.expression)
  scale_denom <- if (max_expr != 0) max_expr else 1
  signaling_df$scaled.mean.expression <- signaling_df$mean.expression / scale_denom
  signaling_df["ligand.arc"] <- 1
  # receptor arc will always sum to 4 no matter how many ligands and cell idents are plotted
  signaling_df["receptor.arc"] <- 4 / nrow(signaling_df)

  signaling_df
}

#' Render Circos Ligand Receptor Plot
#' 
#' Renders a circos plot from the output of [obtain_circos_expression()] to the active graphics device
#' 
#' @param signaling_df Data frame output from [obtain_circos_expression()]
#' @param receptor Name of a receptor active in at least one cell type in the domino object
#' @param ligand_expression_threshold Minimum mean expression value of a ligand by a cell type for a chord to be rendered between the cell type and the receptor
#' @param cell_colors Named vector of color names or hex codes where names correspond to the plotted cell types and values are the colors used for them. Must contain a color for every plotted cell type; extra entries are ignored. If NULL, default ggplot-style colors are generated.
#' @return a circlize plot is rendered to the active graphics device
#' @export render_circos_ligand_receptor
#' @examples 
#' example(build_domino, echo = FALSE)
#' #basic usage
#' circos_df <- obtain_circos_expression(pbmc_dom_built_tiny, receptor = "CXCR3", ligands = "CCL20")
#' render_circos_ligand_receptor(signaling_df = circos_df, receptor = "CXCR3")
#'

render_circos_ligand_receptor <- function(
    signaling_df, receptor, cell_colors = NULL, ligand_expression_threshold = 0.01
  ){
  ligands <- sort(unique(signaling_df$ligand))

  # colors for [cell_ident] arcs
  cell_idents <- sort(unique(signaling_df$sender))
  if (is.null(cell_colors)) {
    cell_colors <- ggplot_col_gen(length(cell_idents))
    names(cell_colors) <- cell_idents
  }
  missing_idents <- setdiff(cell_idents, names(cell_colors))
  if (length(missing_idents) > 0) {
    stop("cell_colors is missing colors for cell types: ", paste(missing_idents, collapse = ", "))
  }
  # align colors with the alphabetically sorted plotted cell types so the legend
  # matches the plot even when cell_colors holds extra, unplotted cell types
  cell_colors <- cell_colors[cell_idents]

  # chords colored by ligand type
  lig_colors <- ggplot_col_gen(length(ligands))
  names(lig_colors) <- ligands
  origin_cols <- unname(lig_colors[as.character(signaling_df$ligand)])

  # first index of color vector set to white to hide receptor arc
  grid_col <- c("#FFFFFF", origin_cols)
  names(grid_col) <- c(receptor, signaling_df$origin)

  # group each "<cell>-<ligand>" sector by its sending cell type, taken directly
  # from the sender column (stripping ligand names with a regex mis-grouped
  # ligands that are prefixes of other ligands, e.g. CCL1 vs CCL16)
  group <- structure(
    c(receptor, as.character(signaling_df$sender)),
    names = c(receptor, signaling_df$origin)
  )

  circlize::circos.clear()
  circlize::circos.par(start.degree = 0)
  circlize::chordDiagram(
    signaling_df[,c("origin", "destination", "ligand.arc", "receptor.arc")], group = group, 
    grid.col = grid_col, link.visible = FALSE, 
    annotationTrack = c("grid"),
    preAllocateTracks = list(
      track.height = circlize::mm_h(4), 
      track.margin = c(circlize::mm_h(2), 0)
    ), 
    big.gap = 2
  )

  # Chord widths: when the largest mean expression exceeds 1, the max-scaled
  # expression is used and the legend reports the rounded maximum. Computed
  # once, before the loop, so max_width exists even if no chord is drawn.
  max_expr <- max(signaling_df[["mean.expression"]])
  use_scaled <- max_expr > 1
  max_width <- if (use_scaled) signif(max_expr, 2) else 1
  width_col <- if (use_scaled) "scaled.mean.expression" else "mean.expression"
  for (i in seq_len(nrow(signaling_df))) {
    if (signaling_df[["mean.expression"]][i] > ligand_expression_threshold) {
      send <- signaling_df$origin[i]
      expr <- signaling_df[[width_col]][i]
      # ligand sectors have width 1 (chord centered at 0.5); the receptor
      # sector spans 4, so chords land at its midpoint 2
      circlize::circos.link(
        send, c(0.5 - (expr / 2), 0.5 + (expr / 2)), receptor, 2,
        col = paste0(grid_col[[send]], "88")
      )
    }
  }
  sector_names <- circlize::get.all.sector.index()

  # highlight each cell type's sectors, matched exactly via the sender column
  # (a name-prefix match would let cell type "T" also capture "T_reg" sectors)
  for (cell in cell_idents) {
    cell_origins <- signaling_df$origin[signaling_df$sender == cell]
    row_pick <- sector_names[sector_names %in% cell_origins]
    if (length(row_pick)) {
      circlize::highlight.sector(
        row_pick,
        track.index = 1, col = cell_colors[cell], 
        text = cell, cex = 1, facing = "inside", text.col = "black",
        niceFacing = FALSE, text.vjust = -1.5
      )
    }
  }

  # highlight receptor sector (exact name match)
  circlize::highlight.sector(
    sector_names[sector_names == receptor],
    track.index = 1, col = "#FFFFFF", 
    text = receptor, cex = 1.5, facing = "clockwise", text.col = "black", 
    niceFacing = TRUE, pos = 4
  )
  # create legends
  lgd_cells <- ComplexHeatmap::Legend(
    at = as.character(cell_idents), type = "grid", legend_gp = grid::gpar(fill = cell_colors),
    title_position = "topleft", title = "cell identity"
  )
  lgd_ligands <- ComplexHeatmap::Legend(
    at = ligands, type = "grid", legend_gp = grid::gpar(fill = lig_colors), title_position = "topleft",
    title = "ligand"
  )
  chord_width <- 10 / (4 + length(cell_idents) * length(ligands))
  lgd_chord <- ComplexHeatmap::Legend(
    at = c(ligand_expression_threshold, max_width), col_fun = circlize::colorRamp2(c(
      ligand_expression_threshold,
      max_width
    ), c("#DDDDDD", "#DDDDDD")), legend_height = grid::unit(chord_width, "in"), title_position = "topleft",
    title = "ligand expression"
  )
  lgd_list_vertical <- ComplexHeatmap::packLegend(lgd_cells, lgd_ligands, lgd_chord)
  ComplexHeatmap::draw(lgd_list_vertical, x = grid::unit(0.02, "npc"), y = grid::unit(0.98, "npc"), just = c("left", "top"))
}

#' Plot differential linkages among domino results ranked by a comparative statistic
#'
#' Plot differential linkages among domino results ranked by a comparative statistic
#'
#' @param differential_linkages a data frame output from the [test_differential_linkages()] function
#' @param test_statistic column name of differential_linkages where the test statistic used for ranking linkages is stored (ex. 'p.value')
#' @param stat_range a two value vector of the minimum and maximum values of test_statistic for plotting linkage features. Features whose statistic is NA are excluded.
#' @param stat_ranking 'ascending' (lowest value of test statistic is colored red and plotted at the top) or 'descending' (highest value of test statistic is colored red and plotted at the top).
#' @param group_palette a named vector of colors to use for each group being compared
#' @return A ComplexHeatmap heatmap list (the test statistic heatmap plus row annotations) of features ranked by test_statistic, annotated with the proportion of subjects that showed active linkage of the features.
#' @export
#' @examples
#' example(build_domino, echo = FALSE)
#' example(test_differential_linkages, echo = FALSE)
#' plot_differential_linkages(
#'  differential_linkages = tiny_differential_linkage_c1,
#'  test_statistic = "p.value",
#'  stat_ranking = "ascending"
#' )
#'
plot_differential_linkages <- function(
    differential_linkages, test_statistic, stat_range = c(0, 1),
    stat_ranking = c("ascending", "descending"), group_palette = NULL) {
  if (!test_statistic %in% colnames(differential_linkages)) {
    stop("test statistic '", test_statistic, "' not present in colnames(differential_linkages)")
  }
  if (identical(stat_ranking, c("ascending", "descending"))) {
    warning("stat_ranking order not specified. Defaulting to ascending order")
    stat_ranking <- "ascending"
  }
  if (length(stat_ranking) != 1 || !stat_ranking %in% c("ascending", "descending")) {
    stop("stat_ranking must be 'ascending' or 'descending'")
  }
  # limit to features within stat range; which() drops features whose statistic
  # is NA instead of turning them into all-NA rows
  stat_values <- differential_linkages[[test_statistic]]
  in_range <- which(stat_values >= stat_range[1] & stat_values <= stat_range[2])
  df <- differential_linkages[in_range, ]
  if (nrow(df) == 0) {
    stop("No features with '", test_statistic, "' within stat_range")
  }
  # order df by plot statistic (ties broken by total_count in the same direction)
  descending <- stat_ranking == "descending"
  df <- df[order(df[[test_statistic]], df[["total_count"]], decreasing = descending), ]
  # the top-ranked end of the statistic is colored red
  stat_gradient <- if (descending) c("#FFFFFF", "#FF0000") else c("#FF0000", "#FFFFFF")

  # values from test result for plotting
  cluster <- unique(df[["cluster"]])
  # per-group columns are "<group>_n" / "<group>_count"; "total_*" are pooled
  g_names_full <- colnames(df)[grepl("_n$", colnames(df)) & !grepl("^total_", colnames(df))]
  # strip only the trailing "_n" so group names containing "_n" stay intact
  g_names <- sub("_n$", "", g_names_full)
  # proportion bar for linkage feature in all subjects
  ha_subject <- ComplexHeatmap::HeatmapAnnotation(
    subjects = ComplexHeatmap::anno_barplot(
      matrix(ncol = 2, c(df[["total_count"]], df[["total_n"]] - df[["total_count"]])),
      gp = grid::gpar(fill = c("black", "white"))
    ),
    which = "row",
    annotation_name_gp = grid::gpar(fontsize = 8)
  )
  ha_subject@anno_list$subjects@label <- "All\nSubjects"
  # row annotation of linkage feature names
  ha_name <- ComplexHeatmap::rowAnnotation(
    feat = ComplexHeatmap::anno_text(df[["feature"]], location = 0, rot = 0)
  )
  # plotted statistic for ordering results
  mat <- matrix(df[[test_statistic]], ncol = 1)
  rownames(mat) <- df[["feature"]]
  ht <- ComplexHeatmap::Heatmap(
    matrix = mat, cluster_rows = FALSE, left_annotation = ha_name,
    cell_fun = function(j, i, x, y, width, height, fill) {
      grid::grid.text(sprintf("%.3f", mat[i, j]), x, y, gp = grid::gpar(fontsize = 6))
    },
    column_title = paste0(cluster, ": ", test_statistic), name = test_statistic,
    col = circlize::colorRamp2(breaks = stat_range, colors = stat_gradient),
    height = nrow(mat) * grid::unit(0.25, "in"), width = grid::unit(1, "in")
  ) + ha_subject
  # generate a heatmap annotation for each group
  if (is.null(group_palette)) {
    group_palette <- ggplot_col_gen(length(g_names))
    names(group_palette) <- g_names
  }
  for (g in g_names) {
    g_count <- paste0(g, "_count")
    g_n <- paste0(g, "_n")
    ha <- ComplexHeatmap::HeatmapAnnotation(
      group = ComplexHeatmap::anno_barplot(
        matrix(ncol = 2, c(df[[g_count]], df[[g_n]] - df[[g_count]])),
        gp = grid::gpar(fill = c(group_palette[g], "#FFFFFF"))
      ),
      name = g, which = "row",
      annotation_name_gp = grid::gpar(fontsize = 8)
    )
    ha@anno_list$group@label <- g
    ht <- ht + ha
  }
  ht
}

#' Normalize a matrix to its max value by row or column
#'
#' Normalizes a matrix to its max value by row or column. Each row (or column)
#' is divided by its own maximum; a row/column whose maximum is 0 yields NaN and
#' one containing NA yields NA, as with plain division.
#'
#' @param mat Matrix to be normalized
#' @param dir Direction to normalize the matrix (either "row" for row or "col" for column)
#' @return A normalized matrix in the direction specified, with the same
#'   dimensions and dimnames as `mat`. Errors if `dir` is not "row" or "col".
#' @keywords internal
#'
do_norm <- function(mat, dir) {
  if (dir == "row") {
    margin <- 1
  } else if (dir == "col") {
    margin <- 2
  } else {
    stop("dir must be either 'row' or 'col'")
  }
  mat <- as.matrix(mat)
  # sweep() keeps the original shape and dimnames even for single-row or
  # single-column matrices, where apply() collapses its result to a vector
  sweep(mat, margin, apply(mat, margin, max), FUN = "/")
}

#' Generate ggplot colors
#'
#' Accepts a number of colors to generate and generates a ggplot color spectrum:
#' `n` evenly spaced hues on the HCL color wheel at luminance 65 and chroma 100.
#'
#' @param n Number of colors to generate
#' @return A character vector of `n` hex colors according to ggplot color generation.
#' @keywords internal
#' 
ggplot_col_gen <- function(n) {
  # n + 1 evenly spaced hues over a full turn (375 = 15 + 360); the last hue
  # duplicates the first and is dropped
  hues <- seq(15, 375, length.out = n + 1)
  grDevices::hcl(h = hues, l = 65, c = 100)[seq_len(n)]
}
