# Simulations

library(sensemakr)
library(MendelianRandomization)
library(MRPRESSO)
library(paramtest)
library(MRMix)
library(genius)
library(snow)

# Objective function used to find the critical k (minimised by optimise()).
#
# Arguments:
#   ks                    value forwarded to ovb_bounds() as `kd`.
#   model                 fitted model forwarded to ovb_bounds().
#   treatment             name of the treatment coefficient in `model`.
#   benchmark_covariates  benchmark covariate name(s) forwarded to ovb_bounds().
#
# Returns the squared adjusted confidence bound: the lower bound when the
# treatment estimate is positive, the upper bound otherwise. Returns Inf when
# the treatment coefficient is missing/NA or when ovb_bounds() raises an error,
# so a minimiser simply steers away from such values.
search <- function(ks, model, treatment = "prs", benchmark_covariates = "x") {
  estimate <- unname(coef(model)[treatment])

  # Without an estimate there is no sign to choose a bound by; use the same
  # fallback value as a failed ovb_bounds() call.
  if (is.na(estimate)) {
    return(Inf)
  }

  # The bound closest to zero is the lower one for positive estimates and the
  # upper one for non-positive estimates.
  bound <- if (estimate > 0) "adjusted_lower_CI" else "adjusted_upper_CI"

  # Squaring makes the objective smallest where the adjusted bound hits zero.
  suppressWarnings(tryCatch(
    ovb_bounds(model = model, treatment = treatment,
               benchmark_covariates = benchmark_covariates,
               kd = ks)[[bound]]^2,
    error = function(e) Inf
  ))
}

# Simulate one sample: genotypes G, covariates w and x, latent u,
# exposure d and outcome y.
#
# Arguments:
#   n       number of observations.
#   m       number of genetic variants (columns of G).
#   phi     length-m effects of G on w.
#   beta    length-m effects of G on d.
#   delta   length-m effects of G on x.
#   tau     effect of d on y.
#   gamma   effect of w on y.
#   psi     effect of both w and x on d.
#   eta     effect of x on y.
#   scale   if TRUE, standardise y, d, x, w and every column of G
#           (u is returned unscaled).
#
# Returns a list with elements y, d, x, w, G (n x m matrix, columns
# "g1".."gm") and u.
create_data <- function(n = 1e5,
                        m = 90,
                        phi,
                        beta,
                        delta,
                        tau   = 0,
                        gamma = .1,
                        psi   = 1,
                        eta   = .1,
                        scale = TRUE) {
  # genotype matrix: Binomial(2, 1/3) draws, filled column by column
  G <- matrix(rbinom(n * m, 2, prob = 1 / 3),
              nrow = n, ncol = m,
              dimnames = list(NULL, paste0("g", 1:m)))

  # the order of the random draws below is part of the function's behaviour
  w <- as.vector(G %*% phi + rnorm(n))
  x <- as.vector(G %*% delta + rnorm(n))
  u <- rnorm(n)
  d <- as.vector(G %*% beta + psi * w + psi * x + u + rnorm(n))
  y <- as.vector(tau * d + gamma * w + eta * x + u + rnorm(n))

  if (scale) {
    # `scale(v)` resolves to the scale() function: R skips the logical
    # argument of the same name when looking up something to call.
    standardise <- function(v) as.numeric(scale(v))
    w <- standardise(w)
    x <- standardise(x)
    d <- standardise(d)
    y <- standardise(y)
    G <- apply(G, 2, standardise)
  }

  list(y = y, d = d, x = x, w = w, G = G, u = u)
}

# Per-variant association estimates for the exposure and the outcome.
#
# Regresses d on G and y on G (each with an intercept) and returns a
# data.frame with one row per reported slope coefficient and columns
#   bd, bdse : estimate and standard error from lm(d ~ G)
#   by, byse : estimate and standard error from lm(y ~ G)
get_betas <- function(y, d, G){
  # slope estimates (column 1) and standard errors (column 2) of the
  # coefficient table, with the intercept row dropped
  slope_table <- function(fit) {
    tab <- coef(summary(fit))
    list(est = tab[-1, 1], se = tab[-1, 2])
  }

  exposure <- slope_table(lm(d ~ G))
  outcome  <- slope_table(lm(y ~ G))

  data.frame(bd   = exposure$est,
             by   = outcome$est,
             bdse = exposure$se,
             byse = outcome$se)
}


# Run one replication of the simulation and return a one-row data.frame.
#
# Arguments:
#   nsim    not referenced in the body (presumably the iteration index supplied
#           by the simulation driver, e.g. grid_search -- confirm).
#   n       sample size of each of the two simulated samples.
#   m       number of genetic variants.
#   gamma   effect of w on y (rescaled by sqrt(k) when R1.b is TRUE).
#   eta     effect of x on y; defaults to the value of `gamma` at call time.
#   tau     effect of d on y.
#   presso  if TRUE, also run MR-PRESSO and add column `p.presso`.
#   R1.a    if TRUE, sets the first j variants to be valid, i.e. their effects
#           on x and w (delta, phi) are set to zero.
#   j       number of variants made valid under R1.a.
#   R1.b    if TRUE, makes W "k times" stronger than X: phi and gamma are both
#           multiplied by sqrt(k).
#   k       strength multiplier used under R1.b.
#   scale   forwarded to create_data().
#
# Returns a one-row data.frame with the settings (n, tau, m, R1.a, j, R1.b, k,
# gamma, eta), partial R2 measures (r2yx.z, r2yw.z, r2zx, r2zw), sanity checks
# (true.tau, iv.tau), p-values of the MR methods (p.genius, p.ivw, p.egger,
# p.conmix, p.mbe, p.median, p.mix and optionally p.presso) and the sensemakr
# results (p.bench, rv, r2yz, ks, or_t, adj_t).
sim <- function(nsim, n = 100, m= 90, gamma = 0.1, eta = gamma, tau = 0,  presso = FALSE,
                R1.a = FALSE, j = 1, R1.b = FALSE, k = 1, scale = TRUE){
  
  # `eta` defaults lazily to `gamma`; evaluate it now so the R1.b rescaling of
  # `gamma` further down does not change `eta` as well
  force(eta)
  # initialize output
  out       <- data.frame(n = n)
  out$tau   <- tau
  out$m     <- m
  out$R1.a  <- R1.a
  out$j     <- j
  out$R1.b  <- R1.b
  out$k     <- k
  
  phi   <- runif(m, .01, .05)
  beta  <- runif(m, .01, .05)
  delta <- runif(m, .01, .05)
  
  if(R1.a){
    # seq_len() keeps j = 0 a no-op (1:0 would zero the first variant)
    valid <- seq_len(j)
    delta[valid] <- 0
    phi[valid] <- 0
  }
  
  
  if(R1.b){
    phi <- sqrt(k)*phi
    gamma <- sqrt(k)*gamma
  }
  
  # two samples of size n
  data1 <- create_data(n = n, phi = phi, beta = beta, delta = delta, tau = tau, gamma = gamma, eta = eta, scale = scale)
  b1    <- with(data1, get_betas(y, d, G))
  
  # memory
  rm(data1)
  gc()
  
  data2 <- create_data(n = n, phi = phi, beta = beta, delta = delta, tau = tau, gamma = gamma, eta = eta, scale = scale)
  b2    <- with(data2, get_betas(y, d, G))
  
  # compute everything that needs data2, then remove it
  # PRS: sample-2 genotypes weighted by the sample-1 exposure estimates
  data2$prs <- data2$G %*% b1$bd
  
  # MR Genius
  form <- paste0("A~", paste0(colnames(data2$G), collapse = "+"))
  form <- as.formula(form)
  output.genius <- genius_addY(Y = data2$y, A = data2$d, G = data2$G, formula = form)
  out$p.genius  <- output.genius$pval
  
  # RF sensitivity
  rf        <- with(data2, lm(y ~ prs + x))
  out$r2yx.z    <- partial_r2(rf)["x"]
  
  model.yw  <- with(data2, lm(y ~ prs + w))
  out$r2yw.z    <- partial_r2(model.yw)["w"]
  
  model.zx   <- with(data2, lm(prs ~ x))
  out$r2zx    <- partial_r2(model.zx)["x"]
  
  model.zw   <- with(data2, lm(prs ~ w))
  out$r2zw    <- partial_r2(model.zw)["w"]
  
  # note: gamma is recorded after the R1.b rescaling
  out$gamma <- gamma
  out$eta   <- eta
  
  
  # sanity checks
  true.model <- lm(y~d + w + x + u, data = data2)
  out$true.tau   <- coef(true.model)["d"]
  rf.full    <- lm(y ~ prs + x + w, data = data2)
  fs.full    <- lm(d ~ prs + x + w, data = data2)
  # ratio of reduced-form to first-stage coefficient on the PRS
  out$iv.tau     <- coef(rf.full)["prs"]/coef(fs.full)["prs"]
  
  # memory
  rm(data2)
  gc()
  
  # MR input: outcome associations from sample 1, exposure associations from sample 2
  input.mr   <- mr_input(by = b1$by, byse = b1$byse, bx = b2$bd, bxse = b2$bdse)
  
  
  # MR IVW
  output.ivw <- mr_ivw(input.mr)
  out$p.ivw  <- output.ivw@Pvalue
  
  # MR egger
  output.egger <- mr_egger(input.mr)
  out$p.egger  <- output.egger@Causal.pval
  
  # Other MR
  
  output.conmix <- mr_conmix(input.mr)
  out$p.conmix  <- output.conmix@Pvalue
  
  output.mbe   <- mr_mbe(input.mr)
  out$p.mbe    <- output.mbe@Pvalue
  
  output.median <- mr_median(input.mr)
  out$p.median  <- output.median@Pvalue
  
  # output.lasso <- mr_lasso(input.mr)
  # output.lasso@Pvalue
  
  if(presso){
    # MR Presso
    res           <- data.frame(by = b1$by, byse = b1$byse, bx = b2$bd, bxse = b2$bdse)
    output.presso <- mr_presso(BetaOutcome = "by", BetaExposure = "bx", SdOutcome = "byse", SdExposure = "bxse",
                               OUTLIERtest = TRUE, DISTORTIONtest = FALSE, data = res, NbDistribution = 1000,
                               SignifThreshold = 0.05)
    presso.p.values <- output.presso$`Main MR results`$`P-value`
    # use the second reported p-value, falling back to the first when it is NA
    # (presumably outlier-corrected vs. raw estimate -- confirm against MRPRESSO)
    p.presso <- presso.p.values[2]
    if(is.na(p.presso)){
      p.presso <- presso.p.values[1]
    }
    out$p.presso <-  p.presso
  }
  
  #MR-Mix (here exposure associations come from sample 1, outcome from sample 2)
  output.mix <- MRMix(betahat_x = b1$bd, betahat_y = b2$by, sx = b1$bdse, sy =b2$byse,profile = FALSE)
  out$p.mix  <- output.mix$pvalue_theta
  
  ## sensemakr
  sense.out <- sensemakr(rf, treatment = "prs", benchmark_covariates = "x", kd = 1)
  rv        <- sense.out$sensitivity_stats[,"rv_qa"]
  r2yz      <- sense.out$sensitivity_stats[,"r2yd.x"]
  or_t      <- sense.out$sensitivity_stats$t_statistic
  adj_t     <- sense.out$bounds$adjusted_t
  s.t       <- sign(or_t)
  s.adj_t   <- sign(adj_t)
  
  # search k: minimise the squared adjusted confidence bound over [0, 10]
  if(rv > 0){
    ks <- optimise(f = search, interval = c(0, 10), 
                   model = rf, 
                   treatment = "prs", 
                   benchmark_covariates = "x")$minimum
  } else {
    ks <- 0
  }
  
  # benchmark-adjusted two-sided p-value; a sign flip of the t-statistic after
  # adjustment is reported as p = 1
  if(s.t != s.adj_t){
    p.bench <- 1
  } else {
    # NOTE(review): df = n rather than the residual df of `rf` -- confirm intended
    p.bench <- pt(abs(adj_t), df = n, lower.tail = FALSE)*2
  }
  out$p.bench  <- p.bench
  out$rv       <- rv
  out$r2yz     <- r2yz
  out$ks       <- ks
  out$or_t     <- or_t
  out$adj_t    <- adj_t
  return(out)
}


#### Use the code below to run the simulation 
#### Warning: it takes a long time (several days on a regular computer)
####          so simulate once to see how long it takes, and plan accordingly.

## test run
# system.time(test <- sim(1, n = 1e5, presso = T, gamma = 0.05, tau = 0))

## now simulate as many times as you want, with different parameter values
# system.time({
#   sim.out <- grid_search(sim, 
#                          n.iter = 1000, 
#                          params = list(n = c(150000, 300000, 450000), gamma = 0.05, tau = c(0, 0.1), presso = T), 
#                          parallel = "multicore", 
#                          ncpus = 32
#   )
# })