#' Compute group-summarised scores and rank genes by the processed scores
#'
#' Dispatcher around the two marker-ranking back ends: when `use.glm` is
#' `TRUE` the call is forwarded to [top_markers_glm()], otherwise to
#' [top_markers_abs()]. `batch` is only forwarded to [top_markers_glm()];
#' it is not passed on when [top_markers_abs()] is used.
#'
#' @inheritParams top_markers_abs
#' @inheritParams top_markers_glm
#' @param use.glm logical, if to use [stats::glm()] to compute group mean score,
#'     if TRUE, also compute mean score difference as output
#' @param ... further params for [top_markers_abs()] or [top_markers_glm()],
#'     passed on unchanged
#'
#' @return a tibble with feature names, group labels and ordered processed scores
#' @export
#'
#' @examples
#' data <- matrix(rgamma(100, 2), 10, dimnames = list(1:10))
#' top_markers_init(data, label = rep(c("A", "B"), 5))
top_markers_init <- function(data, label, n = 10,
                             use.glm = TRUE,
                             batch = NULL,
                             scale = TRUE,
                             use.mgm = TRUE,
                             softmax = TRUE,
                             ...) {
  ## GLM route: the only back end that takes a batch label
  if (use.glm == TRUE) {
    return(
      top_markers_glm(
        data = data, label = label, batch = batch, n = n,
        scale = scale, use.mgm = use.mgm, softmax = softmax,
        ...
      )
    )
  }

  ## direct group summary (median / MAD / mean) route
  top_markers_abs(
    data = data, label = label, n = n,
    scale = scale, use.mgm = use.mgm, softmax = softmax,
    ...
  )
}

#' Calculate group median, MAD or mean score and order genes based on scores
#'
#' The data are optionally scaled (with [scale_mgm()] when `use.mgm` is `TRUE`,
#' otherwise by row-wise z-scoring), summarised per group with `method`,
#' optionally passed through a softmax within each group, and finally the `n`
#' highest scoring genes of every group are returned.
#'
#' @inheritParams scale_mgm
#' @inheritParams top_markers_glm
#' @param method character, specify metric to compute, can be one of "median",
#'     "mad", "mean"
#'
#' @return a tibble grouped by `.dot` (group label) with columns `.dot`,
#'     `Genes` (feature names) and `Scores` (ordered processed scores)
#' @export
#'
#' @examples
#' data <- matrix(rgamma(100, 2), 10, dimnames = list(1:10))
#' top_markers_abs(data, label = rep(c("A", "B"), 5))
top_markers_abs <- function(data, label, n = 10,
                            pooled.sd = FALSE,
                            method = c("median", "mad", "mean"),
                            scale = TRUE, use.mgm = TRUE,
                            softmax = TRUE,
                            tau = 1) {
  method <- match.arg(method)

  ## resolve the summary function explicitly rather than by string lookup, so
  ## it can never be shadowed by an object of the same name on the search path
  summarise_fun <- switch(method,
    median = stats::median,
    mad = stats::mad,
    mean = base::mean
  )

  ## scale
  if (scale && use.mgm) {
    data <- scale_mgm(expr = data, label = label, pooled.sd = pooled.sd)
  } else if (scale) {
    ## z-score every row (feature)
    data <- t(base::scale(t(data)))
    data[is.na(data)] <- 0 # assign 0 to NA when sd = 0
  }

  data <- data |>
    t() |>
    as.data.frame() |>
    dplyr::group_by(.dot = label) |> ## group samples by label
    dplyr::summarise(
      dplyr::across(
        dplyr::everything(),
        \(x) summarise_fun(x, na.rm = TRUE)
      )
    ) |> ## aggregate scores per group
    tidyr::pivot_longer(
      -`.dot`,
      names_to = "Genes", values_to = "Scores"
    ) |> ## transform into long data
    dplyr::group_by(`.dot`) ## group by label again

  ## softmax is applied within each group, before the top n are picked
  if (softmax == TRUE) {
    data <- dplyr::mutate(data, Scores = softmax(Scores, tau = tau))
  }

  ## extract top n markers per group (ties are kept by slice_max)
  dplyr::slice_max(data, Scores, n = n)
}

#' Calculate group mean score using glm and order genes based on scores difference
#'
#' For every feature a [stats::glm()] without intercept is fitted on the group
#' label (plus the batch label when `batch` is given). The score of a feature
#' in a group is the group coefficient minus the largest coefficient among all
#' other groups. Scores are optionally passed through a softmax within each
#' group, and the `n` highest scoring features of every group are returned.
#'
#' @inheritParams scale_mgm
#' @param data matrix, features in row and samples in column
#' @param n integer, number of returned top genes for each group
#' @param family family for glm, details in [stats::glm()]. As a guide: for
#'     continuous non-negative scores gamma or inverse.gaussian can be used,
#'     for continuous unbounded scores gaussian, for discrete scores poisson,
#'     and for binary data, proportions between 0 and 1 or binary frequency
#'     counts binomial
#' @param batch a vector of batch labels, default NULL
#' @param scale logical, if to scale data by row
#' @param use.mgm logical, if to scale data using [scale_mgm()]
#' @param softmax logical, if to apply softmax transformation on output
#' @param tau numeric, hyper parameter for softmax
#'
#' @return a tibble grouped by `.dot` (group label) with columns `.dot`,
#'     `Genes` (feature names, reported exactly as given in the row names of
#'     `data`) and `Scores` (ordered processed scores)
#' @export
#'
#' @examples
#' data <- matrix(rgamma(100, 2), 10, dimnames = list(1:10))
#' top_markers_glm(data, label = rep(c("A", "B"), 5))
top_markers_glm <- function(data, label, n = 10,
                            family = gaussian(),
                            batch = NULL,
                            scale = TRUE, use.mgm = TRUE,
                            pooled.sd = FALSE,
                            softmax = TRUE,
                            tau = 1) {
  label <- factor(label) # factorize label

  ## scale
  if (scale && use.mgm) {
    data <- scale_mgm(expr = data, label = label, pooled.sd = pooled.sd)
  } else if (scale) {
    ## z-score every row (feature)
    data <- t(base::scale(t(data)))
    data[is.na(data)] <- 0 # assign 0 to NA when sd = 0
  }

  ## estimate betas based on given group and/or batch label;
  ## result has one row per coefficient and one column per feature
  if (is.null(batch)) {
    ## model with group label only
    betas <- apply(
      data, 1,
      \(s) stats::glm(s ~ 0 + label, family = family)$coefficients
    )
  } else {
    ## factorize batch label
    batch <- factor(batch)
    ## model with both group and batch label
    betas <- apply(
      data, 1,
      \(s) stats::glm(s ~ 0 + label + batch, family = family)$coefficients
    )
    ## only extract betas of group label
    betas <- betas[grep("^label", rownames(betas)), , drop = FALSE]
  }

  ## compute logFC (1 vs max excluding self) for each group; filling a copy
  ## row by row keeps the group x feature shape even for a single feature
  contrast <- betas
  for (i in seq_len(nrow(betas))) {
    contrast[i, ] <- betas[i, ] -
      sparseMatrixStats::colMaxs(betas[-i, , drop = FALSE])
  }
  betas <- contrast

  ## group coefficients follow the order of the factor levels
  rownames(betas) <- levels(label)

  ## unnamed features keep the X1, X2, ... names they were given before
  if (is.null(colnames(betas))) {
    colnames(betas) <- paste0("X", seq_len(ncol(betas)))
  }

  ## check.names = FALSE: do not let data.frame() rewrite feature names
  ## (e.g. "1" -> "X1", "HLA-A" -> "HLA.A")
  data <- data.frame(.dot = rownames(betas), betas, check.names = FALSE) |>
    tidyr::pivot_longer(-`.dot`, names_to = "Genes", values_to = "Scores") |>
    dplyr::group_by(`.dot`) ## group by label again

  ## softmax is applied within each group, before the top n are picked
  if (softmax == TRUE) {
    data <- dplyr::mutate(data, Scores = softmax(Scores, tau = tau))
  }

  ## extract top n markers per group (ties are kept by slice_max)
  dplyr::slice_max(data, Scores, n = n)
}

## sigmoid: multi-label, no need to sum to 1
## x is first divided by its largest absolute value, so the logistic function
## is evaluated on [-1, 1] and the output lies within roughly [0.27, 0.73]
sigmoid <- function(x) {
  peak <- max(abs(x))
  scaled <- x / peak
  1 / (exp(-scaled) + 1)
}

## softmax: [0, 1], one-label, multi-class, sum to 1
## x:   numeric vector of scores
## tau: temperature, x is divided by tau before exponentiating
## The largest finite value is subtracted before exp(); this leaves the result
## mathematically unchanged but avoids Inf / Inf (or 0 / 0) = NaN when
## |x / tau| is large. NA entries stay NA and are ignored in the normalising sum.
softmax <- function(x, tau = 1) {
  x <- x / tau
  finite <- is.finite(x)
  shift <- if (any(finite)) max(x[finite]) else 0
  w <- exp(x - shift)
  w / sum(w, na.rm = TRUE)
}

## tanh: [-1, 1], similar to sigmoid, no need to sum 1
## Delegates to base::tanh: the closed form 2 / (1 + exp(-2 * x)) - 1 loses all
## precision near zero (it returns exactly 0 for tiny x).
tanh <- function(x) base::tanh(x)

## `.dot` and `Scores` are column names referenced unquoted inside dplyr/tidyr
## verbs in this file; declare them so R CMD check does not report them as
## undefined global variables
utils::globalVariables(c(".dot", "Scores"))
