## ----setup, echo=FALSE, results="hide"------------------------------------------------------------
# Default chunk options for the rest of the document, grouped by purpose.
knitr::opts_chunk$set(
  # what is shown in the rendered output
  echo = FALSE,
  message = FALSE,
  warning = FALSE,
  error = FALSE,
  package.startup.message = FALSE,
  # how code is processed
  tidy = FALSE,
  cache = TRUE,
  # figure devices
  dev = c("png", "pdf")
  # Options kept for reference but currently disabled:
  # collapse = TRUE,
  # comment = "#>",
  # fig.path = ""  # added to the standard setup chunk
)

# Console/output width used when printing.
options(width = 100)

## ----define.fxn, cache=FALSE----------------------------------------------------------------------
library(ggplot2)
library(crumblr)
library(HMP)
library(parallel)
library(glue)
library(tidyverse)

#' Draw Dirichlet-multinomial count vectors
#'
#' For each of `n` draws, samples a probability vector with `rdirichlet()`
#' and then a single multinomial count vector of total `size` using those
#' probabilities.
#'
#' @param n Number of count vectors to draw.
#' @param size Total count of each vector; passed to `rmultinom()`.
#' @param alpha Dirichlet concentration parameters, one per category.
#' @return An integer matrix with one row per draw and one column per
#'   category. For `n = 0` a matrix with zero rows and `length(alpha)`
#'   columns is returned.
rmultinomdir <- function(n, size, alpha) {
  # seq(1, n) counts down to c(1, 0) when n is 0, which would index rows
  # that do not exist; return the empty result explicitly instead.
  if (n == 0) {
    return(matrix(0L, nrow = 0L, ncol = length(alpha)))
  }

  # one row of category probabilities per draw
  p <- rdirichlet(n, alpha)

  # each element is a (categories x 1) column of counts
  draws <- lapply(seq_len(n), function(i) {
    rmultinom(1, size, prob = p[i, ])
  })

  # bind the columns and transpose so that rows index draws
  t(do.call(cbind, draws))
}

#' Compare empirical and approximate CLR variances under a Dirichlet-multinomial
#'
#' For every combination of total count, target count and overdispersion
#' `tau`, draws `n_sims` count vectors with `Dirichlet.multinomial()`, applies
#' `clr()` with the given pseudocount, and records the empirical variance of
#' the first (target) category next to the closed-form variance evaluated at
#' the expected counts.
#'
#' @param prop_other Numeric vector of proportions for categories beyond the
#'   target and its complement, or `NULL` for a two-category simulation.
#' @param pseudocount Value passed to `clr()` and added to the expected counts
#'   before computing the closed-form variance.
#' @param n_sims Number of count vectors simulated per setting.
#' @param count_totals Total counts per sample to evaluate.
#' @param taus Overdispersion values to evaluate; each must be at least 1.
#'   `tau == 1` is simulated with a very large Dirichlet concentration.
#' @param n_cores Number of cores handed to `mclapply()`. Forced to 1 on
#'   Windows, where `mclapply()` does not support more than one core.
#' @return A data.frame with one row per (count total, target count, tau) and
#'   columns `targetCount`, `countTotal`, `tau`, `countsOther`, `v_empirical`
#'   and `v_theoretical_est`.
run_sim <- function(prop_other = NULL, pseudocount = .5, n_sims = 1000,
                    count_totals = c(500, 5000), taus = c(1, 5, 10),
                    n_cores = 10) {
  # Validate before any work is dispatched: an error raised inside an
  # mclapply() worker comes back as a "try-error" element instead of
  # stopping the run.
  if (!is.numeric(taus) || anyNA(taus) || any(taus < 1)) {
    stop("tau must be at least 1", call. = FALSE)
  }

  # mclapply() relies on forking; more than one core is an error on Windows
  if (.Platform$OS.type == "windows") {
    n_cores <- 1L
  }

  # index of the category whose variances are reported (the target category)
  j <- 1

  # Some thoughts on counts in sequencing studies. NAR G&B
  # doi: 10.1093/nargab/lqaa094
  res_by_total <- lapply(count_totals, function(countTotal) {
    # expected counts taken up by the additional categories (0 when NULL)
    countsOther <- sum(prop_other) * countTotal

    # evaluate on a grid of 100 equally spaced target counts; the grid
    # values are generally not integers
    targetValues <- seq(1, max(countTotal - countsOther - 1, 1), length.out = 100)
    targetValues <- unique(targetValues)

    res_by_target <- mclapply(targetValues, function(targetCount) {
      res_by_tau <- lapply(taus, function(tau) {
        # expected count of the complement of the target category
        targetCount2 <- max(countTotal - targetCount - countsOther, 0)

        # expected counts: target, complement, then any other categories
        if (!is.null(prop_other)) {
          expectedCounts <- c(targetCount, targetCount2, prop_other * countTotal)
        } else {
          expectedCounts <- c(targetCount, targetCount2)
        }

        p <- expectedCounts / sum(expectedCounts)

        if (tau == 1) {
          # large a0 corresponds to multinomial
          alpha <- p * 1e9
        } else {
          # choose a0 so that (countTotal + a0) / (1 + a0) equals tau
          a0 <- (countTotal - tau) / (tau - 1)
          alpha <- p * a0
        }

        # Multinomial-Dirichlet
        counts <- Dirichlet.multinomial(rep(countTotal, n_sims), alpha)

        # drop sparse categories, but always keep target and complement
        keep <- colMeans(counts) > 2
        keep[1:2] <- TRUE
        counts <- counts[, keep, drop = FALSE]
        expectedCounts <- expectedCounts[keep]

        # transform
        x_clr <- clr(counts, pseudocount)

        # per-category variance across simulated samples
        v_empirical <- apply(x_clr, 2, var)

        # expected variance based on *true* values
        D <- ncol(counts)
        p <- (expectedCounts + pseudocount) / sum(expectedCounts + pseudocount)
        v_theoretical_est <- tau * (1 / p - 2 / (p * D) + sum(1 / p) / D^2) / countTotal

        data.frame(
          targetCount = targetCount,
          countTotal = countTotal,
          tau = tau,
          countsOther = countsOther,
          v_empirical = v_empirical[j],
          v_theoretical_est = v_theoretical_est[j]
        )
      })
      do.call(rbind, res_by_tau)
    }, mc.cores = n_cores)
    do.call(rbind, res_by_target)
  })

  do.call(rbind, res_by_total)
}

## ----sim1, cache=TRUE-----------------------------------------------------------------------------
# Simulation with only the target category and its complement.
df_res <- run_sim()

# Reshape to long format: one row per setting and variance type.
id_vars <- c("tau", "targetCount", "countTotal")
measure_vars <- c("v_empirical", "v_theoretical_est")
df_v <- reshape2::melt(df_res[, c(id_vars, measure_vars)], id.vars = id_vars)

# Facet labels are plotmath strings; levels follow order of first appearance
# so that panels are not sorted alphabetically.
df_v <- df_v %>%
  mutate(
    facet_x = glue('tau*" = {tau}"'),
    facet_x = factor(facet_x, unique(facet_x)),
    facet_y = glue('C*" = {countTotal}"'),
    facet_y = factor(facet_y, unique(facet_y))
  )

## ----plot1, fig.height=6, fig.align = 'center'----------------------------------------------------
# Keep settings where both the target count and the remaining count are >= 2.
df_sub <- df_v %>%
  filter(targetCount >= 2, (countTotal - targetCount) >= 2)

# Empirical variances are drawn as points, the approximation as a line.
df_points <- filter(df_sub, variable == "v_empirical")
df_lines <- filter(df_sub, variable != "v_empirical")
method_labels <- c("Empirical (DMN)", "Asymptotic normal approx.")

ggplot(mapping = aes(targetCount / countTotal, sqrt(value), color = variable)) +
  geom_point(data = df_points) +
  geom_line(data = df_lines, linewidth = 1) +
  facet_grid(facet_y ~ facet_x, labeller = label_parsed, scales = "free_y") +
  scale_color_manual(name = "Method", values = c("blue4", "red"), labels = method_labels) +
  # no layer maps linetype, so this scale does not affect the figure
  scale_linetype_manual(name = "Method", values = c(1, 2), labels = method_labels) +
  scale_y_continuous(limits = c(0, NA)) +
  xlab("Proportion") +
  ylab("Standard Deviation") +
  ggtitle("Standard deviations from empirical DMN and asymptotic normal approximation") +
  theme_classic() +
  theme(aspect.ratio = 1, plot.title = element_text(hjust = 0.5), legend.position = "bottom")

## ----sim2, cache=FALSE----------------------------------------------------------------------------
# Simulation with 15 additional categories, each at proportion 0.05.
prop_other <- rep(.05, 15)
df_res <- run_sim(prop_other)

# Reshape to long format: one row per setting and variance type.
id_vars <- c("tau", "targetCount", "countTotal", "countsOther")
measure_vars <- c("v_empirical", "v_theoretical_est")
df_v <- reshape2::melt(df_res[, c(id_vars, measure_vars)], id.vars = id_vars)

# Facet labels are plotmath strings; levels follow order of first appearance
# so that panels are not sorted alphabetically.
df_v <- df_v %>%
  mutate(
    facet_x = glue('tau*" = {tau}"'),
    facet_x = factor(facet_x, unique(facet_x)),
    facet_y = glue('C*" = {countTotal}"'),
    facet_y = factor(facet_y, unique(facet_y))
  )

## ----plot2, fig.height=6, cache=FALSE, fig.align = 'center'---------------------------------------
# Keep settings where the target count and the complement count (total minus
# target minus other categories) are both >= 2.
df_sub <- df_v %>%
  filter(targetCount >= 2, (countTotal - targetCount - countsOther) >= 2)

# Empirical variances are drawn as points, the approximation as a line.
df_points <- filter(df_sub, variable == "v_empirical")
df_lines <- filter(df_sub, variable != "v_empirical")
method_labels <- c("Empirical (DMN)", "Asymptotic normal approx.")

ggplot(mapping = aes(targetCount / countTotal, sqrt(value), color = variable)) +
  geom_point(data = df_points) +
  geom_line(data = df_lines, linewidth = 1) +
  facet_grid(facet_y ~ facet_x, labeller = label_parsed, scales = "free_y") +
  scale_color_manual(name = "Method", values = c("blue4", "red"), labels = method_labels) +
  # no layer maps linetype, so this scale does not affect the figure
  scale_linetype_manual(name = "Method", values = c(1, 2), labels = method_labels) +
  scale_y_continuous(limits = c(0, NA)) +
  xlab("Proportion") +
  ylab("Standard Deviation") +
  ggtitle("Standard deviations from empirical DMN and asymptotic normal approximation") +
  theme_classic() +
  theme(aspect.ratio = 1, plot.title = element_text(hjust = 0.5), legend.position = "bottom")

## ----for.manuscript, eval=FALSE, cache=FALSE------------------------------------------------------
# df_sub <- df_v %>%
#   filter(targetCount >= 2 & (countTotal - targetCount - countsOther) >= 2) %>%
#   filter(tau == 10 & countTotal == 5000)
# 
# fig <- ggplot() +
#   geom_point(data = filter(df_sub, variable == "v_empirical"), aes(targetCount / countTotal, sqrt(value), color = variable)) +
#   geom_line(data = filter(df_sub, variable != "v_empirical"), aes(targetCount / countTotal, sqrt(value), color = variable), linewidth = 1) +
#   theme_classic() +
#   theme(aspect.ratio = 1, plot.title = element_text(hjust = 0.5), legend.position = c(.4, .7), legend.justification = c(0, 0)) +
#   xlab("Proportion") +
#   ylab("Standard Deviation") +
#   scale_color_manual(name = "Method", values = c("blue4", "red"), labels = c("Empirical (DMN)", "Asymptotic normal approx.")) +
#   scale_linetype_manual(name = "Method", values = c(1, 2), labels = c("Empirical (DMN)", "Asymptotic normal approx.")) +
#   ggtitle(expression(tau == 10 ~ C == 5000 ~ D == 15)) +
#   scale_y_continuous(limits = c(0, NA), expand = c(0, 0))
# ggsave("sim_figure.pdf", fig, height = 4, width = 4)

## ----session, echo=FALSE--------------------------------------------------------------------------
# Record the R version and the attached/loaded packages used for this run.
sessionInfo()