# Fit a pseudo-value regression tree (PRT).
#
# Pseudo-values of the survival function are computed at `cutoffs_pseudo`
# (pseudo::pseudosurv), mstop of the node-wise boosting is tuned by
# cross-validation and the model is refit on all data (see tune_mstop()).
#
# Arguments:
#   data            data frame with one row per subject.
#   covariates      character vector of covariate column names.
#   cutoffs_pseudo  time points at which pseudo-values are computed.
#   time_col, status_col, id_col
#                   columns holding observed time, event indicator and id.
#   D               maximal tree depth.
#   fold            number of cross-validation folds.
#   center          if TRUE, non-factor covariates are mean centered.
#
# Returns the list from tune_mstop() (`tree`, `node_boost`) extended by
# `center` and `covariates`.
#
# Side effects: sets the RNG seed and assigns the globals `n_time`, `n_cov`
# and `is_factor`, which tune_mstop() reads.
# NOTE(review): tune_mstop() also reads a global `covariates`; callers must
# define it in the global environment.
PRT <- function(data,
                covariates,
                cutoffs_pseudo,
                time_col,
                status_col,
                id_col,
                D = 3,
                fold = 5,
                center = FALSE) {

  # Warnings are silenced while fitting; the caller's setting is restored.
  old_options <- options(warn = -1)
  on.exit(options(old_options), add = TRUE)

  missing_cov <- setdiff(covariates, names(data))
  if (length(missing_cov) > 0) {
    stop("covariates not found in data: ",
         paste(missing_cov, collapse = ", "), call. = FALSE)
  }

  seed <- 1234
  set.seed(seed)

  # Prepare data set; the globals are read by tune_mstop() and its helpers.
  n_time     <- length(cutoffs_pseudo)
  n_cov      <- length(covariates)
  is_factor  <- sapply(covariates, FUN = check_factor, data = data)
  n_time    <<- n_time
  n_cov     <<- n_cov
  is_factor <<- is_factor

  # Continuous covariates are mean centered column by column; this also
  # works for a single column or no column in a base data.frame.
  if (isTRUE(center)) {
    for (v in covariates[!is_factor]) {
      data[[v]] <- as.numeric(scale(data[[v]], center = TRUE, scale = FALSE))
    }
  }

  # Standardized column names for further usage. Direct assignment, unlike
  # add_column(), does not fail if e.g. the id column is already named "id".
  Ttilde_values <- unlist(data[, time_col])
  status_values <- unlist(data[, status_col])
  id_values     <- unlist(data[, id_col])
  data[["Ttilde"]] <- Ttilde_values
  data[["status"]] <- status_values
  data[["id"]]     <- id_values

  print("Calculate pseudo-values...")

  # Calculate pseudo-values on training data
  pseudo_values <- pseudosurv(time = data$Ttilde, event = data$status,
                              tmax = cutoffs_pseudo)

  # Attach pseudo-values to data in long format (one row per subject and
  # cutoff). The cutoff is whatever follows the first "." of the pseudo-value
  # column name (assumed "<prefix>.<cutoff>"), so decimal cutoffs survive.
  n_orig <- ncol(data)
  data_pseudo <- cbind.data.frame(data, pseudo_values$pseudo) %>%
    pivot_longer(cols = (n_orig + 1):(n_orig + length(cutoffs_pseudo)),
                 names_to = "time", values_to = "psOrig") %>%
    mutate(time  = as.numeric(sub("^[^.]*\\.", "", time)),
           time2 = as.factor(time),
           x0    = 1,
           id2   = seq_len(nrow(.)))

  # Tune mstop via cross-validation and use it for the complete model
  fit <- tune_mstop(data_pseudo = data_pseudo, fold = fold,
                    n_time = n_time, cutoffs_pseudo = cutoffs_pseudo,
                    D = D, r = seed)

  fit$center     <- center
  fit$covariates <- covariates

  fit
}


# Tell whether the covariate column(s) named by `cov` are stored as a factor
# in `data` (used to pick centering and base-learner specification).
# The selection is flattened with unlist(), which keeps the factor class.
check_factor <- function(cov, data) {
  column_values <- unlist(data[, cov])
  inherits(column_values, "factor")
}


# mboost family: squared-error loss on the probability scale with the link
# mu = 1 - exp(-exp(-f)).
#
# ngradient is the negative gradient of 0.5 * (y - mu)^2 with respect to f,
# i.e. (y - mu) * d mu / d f with d mu / d f = -exp(-exp(-f)) * exp(-f).
# That product is evaluated as -exp(-f - exp(-f)) so that very negative f
# (exp(-f) overflowing to Inf) gives 0 instead of 0 * Inf = NaN.
# `response` returns f unchanged (link scale); the prediction code in this
# file applies 1 - exp(-exp(-f)) itself.
GaussCloglog <- function() {
  inv_link <- function(f) 1 - exp(-exp(-f))

  Family(
    ngradient = function(y, f, w = 1) (y - inv_link(f)) * -exp(-f - exp(-f)),
    loss      = function(y, f) (y - inv_link(f))^2,
    offset    = function(y, w) 0,
    check_y   = function(y) {
      if (!is.numeric(y) || !is.null(dim(y))) {
        stop("response is not a numeric vector but ",
             sQuote("family = GaussCloglog()"))
      }
      y
    },
    name      = "Squared Error (Regression) with cloglog link",
    fW        = function(f) rep(1, length.out = length(f)),
    response  = function(f) f
  )
}



# Cross-validated tuning of the boosting iteration count `mstop` for a PRT,
# followed by a refit on all data.
#
# In each of `fold` folds pseudo-values are recomputed on the training and
# test parts, a ctree is grown on the training part and node-wise gamboost
# models are fit for every candidate mstop; the candidate with the lowest mean
# test loss wins. If the optimum lies on a border of the grid, the grid is
# extended in steps of `stepsize` (up to `maximal_mstop`, down to 0).
#
# Arguments:
#   data_pseudo     long data (one row per subject and cutoff) with columns
#                   psOrig, time, time2, id, id2, Ttilde, status, x0 and the
#                   covariates (as built by PRT()).
#   fold            number of CV folds.
#   n_time          number of cutoffs (rows per subject).
#   cutoffs_pseudo  cutoff time points.
#   r               seed multiplier, set.seed(123 * r).
#   D               maximal tree depth.
#
# Returns list(tree = build_tree() result on all data,
#              node_boost = cv_complete_tree(..., complete = TRUE) result).
#
# NOTE(review): reads the globals `covariates`, `n_cov` and `is_factor`
# (n_cov/is_factor are assigned by PRT()), and presumably expects chunk() to
# return one fold label per id -- confirm against the helper's definition.
tune_mstop <- function(data_pseudo, fold = 5, n_time,
                       cutoffs_pseudo, r, D = 3) {

  # Candidate grid and search settings for mstop.
  mstop         <- c(300, 400)
  stepsize      <- 100
  maximal_mstop <- 10000
  min_change    <- 0.1

  # Grow a ctree on one training set and record, for every node, its parent,
  # depth, terminal flag, root-to-node path and the training rows in it.
  # `data_test` is unused.
  build_tree <- function(d_train, data_test = NULL, D, n_t) {

    # Wide format: one row per subject, pseudo-values in y1..y<n_time>.
    d_tmp <- d_train %>%
      dplyr::select(-c(time, id2, id, status, Ttilde)) %>%
      pivot_wider(values_from = psOrig, names_from = time2)
    colnames(d_tmp)[(ncol(d_tmp) - n_time + 1):(ncol(d_tmp))] <-
      c(paste0("y", 1:length(cutoffs_pseudo)))

    # Multivariate response y1 + ... + y<n_time> on all covariates.
    formula_ctree <- as.formula(paste0(paste0("y", 1:n_time, collapse = " + "),
                                       " ~ ", paste0(covariates, collapse = " + ")))

    tree2 <- partykit:::ctree(formula = formula_ctree,
                              data    = d_tmp,
                              pargs   = GenzBretz(abseps = .1, releps = .1),
                              control = partykit:::ctree_control(alpha = 0.05, maxdepth = D,
                                                                 minbucket = 5 * n_t,
                                                                 teststat = "max",
                                                                 testtype = c("MonteCarlo"),
                                                                 splitstat = "max"))

    # Parent of every node (the root keeps NA).
    node_list <- as.list(tree2$node)
    nnodes    <- length(node_list)

    parents <- c()
    for (n in 1:nnodes) {
      node          <- node_list[[n]]
      kids          <- node$kids
      parents[kids] <- node$id
    }

    if (length(parents) != 0) {
      parent_list <- cbind.data.frame(node = 1:nnodes, parent = parents)
    } else {
      # single-node tree: nobody has a parent
      parent_list <- cbind.data.frame(node = 1:nnodes, parent = NA)
    }

    # Depth of a node (root = 1), found by walking up the parent links.
    node_level <- function(n) {
      level <- 0
      while (!is.na(n)) {
        n     <- parent_list$parent[parent_list$node == n]
        level <- level + 1
      }
      return(level)
    }

    parent_list$level    <- sapply(1:nnodes, FUN = node_level)
    # A node is terminal iff it has no children; comparing the level with the
    # maximal level mislabels shallow leaves of unbalanced trees.
    parent_list$terminal <- !(parent_list$node %in% parent_list$parent)

    # Root-to-node path (node ids, top to bottom) for every node.
    paths <- list()
    for (n in 1:nnodes) {
      path_n       <- c()
      current_node <- n
      while (!is.na(current_node)) {
        path_n       <- c(path_n, parent_list$parent[parent_list$node == current_node])
        current_node <- path_n[length(path_n)]
      }
      paths[[n]] <- c(rev(path_n[!is.na(path_n)]), n)
    }

    # Terminal node and comma-separated path of each training row.
    d_train <- d_train %>%
      mutate(end_node = predict(tree2, newdata = d_train, type = "node")) %>%
      mutate(path = sapply(end_node, function(n) {
        return(paste0(paths[[n]], collapse = ","))
      }))

    # Training rows whose path passes through node n.
    in_node <- function(n) {
      d_train %>% filter(sapply(sapply(path, function(p) {
        c(strsplit(p, split = ","))
      }), function(x) {
        n %in% x
      }))
    }

    node_data <- lapply(1:nnodes, FUN = in_node)

    list(tree        = tree2,
         nnodes      = nnodes,
         paths       = paths,
         parent_list = parent_list,
         node_data   = node_data,
         node_list   = node_list)
  }

  # Fit node-wise gamboost models on a grown tree. Node n gets m_stop scaled
  # by its share of the training rows; child models use the parent's mean
  # fitted values (link scale) per cutoff as offset. With complete = TRUE the
  # models are returned; otherwise MSE is the node-share-weighted mean of the
  # per-node test sums of squared errors.
  cv_complete_tree <- function(m_stop, tree_obj, data_train,
                               data_test, complete = FALSE) {

    # Survival-scale predictions for `newdata`, returned in newdata's row order.
    predict_PRT_cv <- function(newdata, tree_obj, model_list, offset_list) {

      n_time <- length(unique(newdata$time))

      # Link-scale prediction for subject i: sum of all base-learner
      # contributions plus (if present) the node offset.
      sub_fun <- function(i) {
        new_x    <- newdata_sorted %>% filter(id == i)
        end_node <- unique(new_x$end_node)

        tmp <- as.matrix(predict(model_list[[end_node]], newdata = new_x,
                                 which = 1:(n_cov + 2)))
        if (!is.null(attr(tmp, "offset"))) {
          pred <- apply(tmp, 1, sum) + offset_list[[end_node]][1:n_time]
        } else {
          pred <- apply(tmp, 1, sum)
        }

        if (length(pred) < n_time) {
          # given that all estimates are zero in an end node
          pred <- rep(pred, n_time)
        }

        return(list(pred = pred, end_node = end_node))
      }

      tree <- tree_obj$tree

      # sort newdata by subject and attach the terminal node
      newdata_sorted <- newdata %>%
        arrange(id) %>%
        mutate(end_node = as.numeric(predict(tree, newdata = ., type = "node")))

      ids_unique <- unique(newdata_sorted$id)

      predictions <- end_nodes <- c()
      for (j in ids_unique) {
        tmp         <- sub_fun(j)
        predictions <- c(predictions, tmp$pred)
        end_nodes   <- c(end_nodes, rep(tmp$end_node, n_time))
      }

      # predictions/end_nodes follow newdata_sorted; map both back to the rows
      # of newdata so that pred, end_node and psOrig stay aligned.
      to_newdata <- match(newdata$id2, newdata_sorted$id2)

      newdata %>%
        dplyr::select(time, psOrig) %>%
        mutate(pred     = 1 - exp(-exp(-predictions[to_newdata])),
               end_node = end_nodes[to_newdata])
    }

    bols_cov <- cbind.data.frame(covariates = c("x0", covariates),
                                 is_factor = c(FALSE, is_factor)) %>%
      mutate(bols_cov = ifelse(is_factor, paste0("bols(", covariates, ")"),
                               paste0("bols(", covariates, ", intercept = FALSE)")))

    formula_nodewise_boosting <- as.formula(paste0(
      "psOrig ~ ", paste0(bols_cov$bols_cov, collapse = " + "),
      " + bmono(time, constraint = 'inc', intercept = FALSE, df = 4)"))

    nnodes      <- tree_obj$nnodes
    parent_list <- tree_obj$parent_list
    node_data   <- tree_obj$node_data

    # per-node mstop: m_stop scaled by the node's share of the training rows
    node_sizes <- sapply(1:tree_obj$nnodes, function(n) {
      nrow(tree_obj$node_data[[n]])
    })
    mstop_vec <- round(node_sizes / nrow(data_train) * m_stop)

    prediction_list <- model_list <- cov_sub <- list()
    offset_list     <- vector("list", nnodes)

    # Root offset: link-transformed mean pseudo-value per cutoff.
    # NOTE(review): rep() assumes data_train is ordered by id, then by
    # ascending time -- confirm.
    mean_ps <- data_train %>%
      group_by(time) %>%
      summarise(mean = mean(psOrig))

    offset <- as.numeric(rep(mean_ps$mean, length(unique(data_train$id))))

    offset_list[[1]] <- -log(-log(1 - offset))

    # model on all data (node 1, root)
    model <- gamboost(formula_nodewise_boosting, data = data_train,
                      control = boost_control(mstop = m_stop, nu = 0.1),
                      family = GaussCloglog(), offset = offset_list[[1]])

    prediction_list[[1]] <- cbind.data.frame(id2 = data_train$id2, pred = fitted(model))
    model_list[[1]]      <- model

    if (nnodes >= 2) {
      for (n in 2:nnodes) {

        set.seed(n)

        parent_node <- parent_list$parent[parent_list$node == n]

        # training rows in the current node
        sub_data <- node_data[[n]] %>%
          filter(id2 %in% data_train$id2)

        # A base-learner is kept only if its covariate varies in this node:
        # non-zero sd for numeric, every level present for factors.
        is_in <- function(covariate) {
          if (covariate == "x0") {
            return(TRUE)
          }
          sub <- sub_data %>% dplyr::select(all_of(covariate)) %>% unlist()
          if (!is.factor(sub)) {
            if (sd(sub) == 0) {
              return(FALSE)
            }
            return(TRUE)
          }
          if (any(table(sub) == 0)) {
            return(FALSE)
          }
          TRUE
        }

        bols_cov_sub <- cbind.data.frame(covariates = c("x0", covariates),
                                         is_factor = c(FALSE, is_factor)) %>%
          mutate(bols_cov = ifelse(is_factor, paste0("bols(", covariates, ")"),
                                   paste0("bols(", covariates, ", intercept = FALSE)")),
                 is_in    = sapply(covariates, FUN = is_in))

        cov_sub[[n]] <- bols_cov_sub

        formula_mboost_tree <- as.formula(paste0(
          "psOrig ~ ", paste0(bols_cov$bols_cov[bols_cov_sub$is_in], collapse = " + "),
          " + bmono(time, constraint = 'inc', intercept = FALSE, df = 4)"))

        # offset: mean parent prediction per cutoff over this node's subjects
        parent_preds <- prediction_list[[parent_node]] %>%
          filter(id2 %in% sub_data$id2) %>%
          mutate(id   = sub_data$id,
                 time = rep(cutoffs_pseudo, nrow(.) / length(cutoffs_pseudo)))

        A <- parent_preds %>%
          pivot_wider(., id_cols = id, names_from = time, values_from = pred)

        .tmp   <- apply(A[2:(n_time + 1)], MARGIN = 2, FUN = mean)
        offset <- as.numeric(rep(.tmp, length(unique(sub_data$id))))

        model <- gamboost(formula_mboost_tree, data = sub_data,
                          control = boost_control(mstop = mstop_vec[n], nu = 0.1),
                          family = GaussCloglog(), offset = offset)

        offset_list[[n]]     <- offset
        prediction_list[[n]] <- cbind.data.frame(id2 = sub_data$id2, pred = fitted(model))
        model_list[[n]]      <- model
      }
    }

    if (complete) {
      return(list(model_list = model_list, MSE = NULL,
                  offset_list = offset_list, mstop_opt = m_stop))
    }

    # predict on the held-out test data of this CV iteration
    pred_gamboost_treeT <- predict_PRT_cv(data_test, tree_obj = tree_obj,
                                          model_list = model_list,
                                          offset_list = offset_list)

    # sum of squared errors within one terminal node
    node_loss <- function(node, pred_data) {
      A <- pred_data %>% filter(end_node == node)
      sum((A$psOrig - A$pred)^2)
    }

    end_nodes <- unique(pred_gamboost_treeT$end_node)
    loss      <- sapply(end_nodes, FUN = node_loss, pred_data = pred_gamboost_treeT)
    # node share of the test rows, in the same order as `loss` (table() would
    # order by node id instead of first appearance)
    weights_nodes <- sapply(end_nodes, function(nd) {
      mean(pred_gamboost_treeT$end_node == nd)
    })

    loss_t_mean <- weighted.mean(loss, w = weights_nodes)
    loss_sum    <- sum(loss_t_mean)

    list(model_list = model_list, MSE = loss_sum, offset_list = offset_list)
  }

  # Test MSE of every candidate in `mstop` for CV fold i.
  PRT_cv <- function(i, mstop, tree_obj_list) {
    tree_obj <- tree_obj_list[[i]]
    modeltree_obj <- lapply(mstop, FUN = cv_complete_tree, tree_obj = tree_obj,
                            data_train = data_train_list_aug[[i]],
                            data_test  = data_test_tree_list_aug[[i]])
    sapply(modeltree_obj, FUN = function(x) {
      x$MSE
    })
  }

  # Recompute pseudo-values on one fold subset and return it in long format.
  fold_pseudo_data <- function(d) {
    d %>%
      dplyr::select(c(-time, -time2, -psOrig, -id2, -cv)) %>%
      distinct() %>%
      cbind.data.frame(., pseudosurv(time = .$Ttilde, event = .$status,
                                     tmax = cutoffs_pseudo)$pseudo) %>%
      pivot_longer(cols = c((ncol(.) - n_time + 1):ncol(.)),
                   names_to = "time", values_to = "psOrig") %>%
      mutate(time  = rep(cutoffs_pseudo, nrow(.) / n_time),
             time2 = as.factor(time),
             id2   = as.numeric(id) + rep(c(0:(n_time - 1) * nrow(.)), nrow(.) / n_time))
  }

  set.seed(123 * r)

  # partition subjects into CV folds
  unique_ids <- unique(data_pseudo$id)
  chunks     <- chunk(unique_ids, n.chunks = fold)

  data_pseudo_aug <- data_pseudo %>%
    mutate(cv = rep(chunks, each = n_time))

  data_train_list_aug <- data_test_tree_list_aug <- tree_obj_list <- list()
  for (i in 1:fold) {
    data_train_list_aug[[i]]     <- fold_pseudo_data(data_pseudo_aug %>% filter(cv != i))
    data_test_tree_list_aug[[i]] <- fold_pseudo_data(data_pseudo_aug %>% filter(cv == i))
    # grow tree on training data for each cv repetition
    tree_obj_list[[i]] <- build_tree(d_train = data_train_list_aug[[i]], D = D, n_t = n_time)
  }

  print("Optimize mstop via cross-validation...")

  MSE_matrix <- sapply(1:fold, FUN = PRT_cv, mstop = mstop, tree_obj_list = tree_obj_list)
  mean_MSE   <- apply(MSE_matrix, FUN = mean, MARGIN = 1)
  ind_min    <- which.min(mean_MSE)
  mstop_opt  <- mstop[ind_min]
  MSE_min    <- mean_MSE[ind_min]

  max_mstop <- max(mstop)
  min_mstop <- min(mstop)

  # optimum on the right border: extend the grid upwards
  while (mstop_opt == max_mstop) {
    mstop <- c(max_mstop + stepsize)
    tmp   <- sapply(1:fold, FUN = PRT_cv, mstop = mstop, tree_obj_list = tree_obj_list)
    MSE   <- mean(tmp)

    # improvement smaller than min_change: accept and stop
    if ((MSE_min - MSE < min_change) && (MSE_min - MSE > 0)) {
      mstop_opt <- mstop
      break
    }

    # upper limit reached
    if (mstop >= maximal_mstop) {
      mstop_opt <- mstop
      break
    }

    if (MSE < MSE_min) {
      mstop_opt <- mstop
      MSE_min   <- MSE
      max_mstop <- mstop
    } else {
      mstop_opt <- max_mstop
      break
    }
  }

  # optimum on the left border: extend the grid downwards (not below 0)
  while (mstop_opt == min_mstop) {
    mstop <- max(0, (min_mstop - stepsize))
    tmp   <- sapply(1:fold, FUN = PRT_cv, mstop = mstop, tree_obj_list = tree_obj_list)
    MSE   <- mean(tmp)

    # improvement smaller than min_change: accept and stop
    if ((MSE_min - MSE < min_change) && (MSE_min - MSE > 0)) {
      mstop_opt <- mstop
      break
    }

    if (MSE < MSE_min) {
      mstop_opt <- mstop
      MSE_min   <- MSE
      min_mstop <- mstop
    } else {
      mstop_opt <- min_mstop
      break
    }
  }

  print("Fit PRT with optimal mstop on all data...")

  # tree and node-wise boosting on all data with the selected mstop
  tree_obj   <- build_tree(d_train = data_pseudo, D = D, n_t = n_time)
  node_boost <- cv_complete_tree(m_stop = mstop_opt, tree_obj = tree_obj,
                                 data_train = data_pseudo_aug, data_test = NULL,
                                 complete = TRUE)

  list(tree = tree_obj, node_boost = node_boost)
}




# Predict survival probabilities from a fitted PRT object.
#
# Arguments:
#   PRT_object  object returned by PRT().
#   newdata     data frame with one row per subject containing the
#               covariates; if NULL, predictions are made for the training
#               data stored in the tree (node_data[[1]]).
#
# Returns the long-format data (one row per subject and cutoff) with added
#   pred      predicted survival probability 1 - exp(-exp(-f)),
#   end_node  terminal node of the subject,
# in the row order of the long data. For non-NULL newdata the long data also
# gets time, time2, x0, id2 and id (id = row number in newdata).
#
# Cutoffs, covariates and the number of covariates are taken from the object
# (cutoffs in first-seen order of the stored training `time` column) rather
# than from globals, so a saved object works in a fresh session.
# Side effect: assigns the global `is_factor` (as before) when newdata is given.
predict.PRT <- function(PRT_object, newdata = NULL) {

  tree_obj    <- PRT_object$tree
  tree        <- tree_obj$tree
  model_list  <- PRT_object$node_boost$model_list
  offset_list <- PRT_object$node_boost$offset_list
  center      <- PRT_object$center
  covariates  <- PRT_object$covariates
  n_cov       <- length(covariates)

  if (is.null(newdata)) {
    # training data with estimated values
    newdata_long <- tree_obj$node_data[[1]]
  } else {
    cutoffs <- unique(tree_obj$node_data[[1]]$time)
    n_cut   <- length(cutoffs)

    is_factor  <- sapply(covariates, FUN = check_factor, data = newdata)
    is_factor <<- is_factor

    if (isTRUE(center)) {
      # NOTE(review): centers with the means of `newdata` itself; the
      # training means are not stored in the PRT object.
      for (v in covariates[!is_factor]) {
        newdata[[v]] <- as.numeric(scale(newdata[[v]], center = TRUE, scale = FALSE))
      }
    }

    # one row per subject and cutoff
    newdata_long <- cbind.data.frame(newdata, matrix(0, ncol = n_cut, nrow = nrow(newdata))) %>%
      pivot_longer(cols = (ncol(newdata) + 1):(ncol(newdata) + n_cut), names_to = "time") %>%
      mutate(time  = rep(cutoffs, nrow(newdata)),
             time2 = as.factor(time),
             x0    = 1,
             id2   = seq_len(nrow(.)),
             id    = rep(seq_len(nrow(newdata)), each = n_cut)) %>%
      dplyr::select(-c("value"))
  }

  n_time <- length(unique(newdata_long$time))

  # sort by subject and attach the terminal node
  newdata_sorted <- newdata_long %>%
    arrange(id) %>%
    mutate(end_node = as.numeric(predict(tree, newdata = ., type = "node")))

  # Link-scale prediction for one subject: sum of all base-learner
  # contributions plus (if present) the offset of its terminal node.
  predict_subject <- function(subject_id) {
    new_x    <- newdata_sorted %>% filter(id == .env$subject_id)
    end_node <- unique(new_x$end_node)

    tmp  <- as.matrix(predict(model_list[[end_node]], newdata = new_x,
                              which = 1:(n_cov + 2)))
    pred <- apply(tmp, 1, sum)
    if (!is.null(attr(tmp, "offset"))) {
      pred <- pred + offset_list[[end_node]][1:n_time]
    }

    if (length(pred) < n_time) {
      # given that all estimates are zero in an end node
      pred <- rep(pred, n_time)
    }

    list(pred = pred, end_node = end_node)
  }

  per_subject <- lapply(unique(newdata_sorted$id), predict_subject)
  predictions <- unlist(lapply(per_subject, function(res) res$pred))
  end_nodes   <- unlist(lapply(per_subject, function(res) rep(res$end_node, n_time)))

  # predictions/end_nodes follow newdata_sorted; map both back to the rows of
  # newdata_long (previously the inverse permutation was used and end_node was
  # not reordered, misaligning results whenever the data was not sorted by id).
  to_long <- match(newdata_long$id2, newdata_sorted$id2)

  newdata_long %>%
    mutate(pred     = 1 - exp(-exp(-predictions[to_long])),
           end_node = end_nodes[to_long])
}


# Summarize the base-learners selected by the boosting model of every node.
#
# Returns a list with
#   selected_bl   data frame, one row per node ("Node 1", ...), with a 1/0
#                 flag for each covariate, "time" and "x0": 1 if that name,
#                 used as a regular expression, matches the name of any
#                 coefficient returned by coef() for the node model.
#   estimates_bl  per-node coefficient lists from coef(), renamed to the text
#                 inside the first parentheses of each base-learner name
#                 (e.g. "bols(x1, intercept = FALSE)" -> "x1, intercept = FALSE").
# Nodes where no base-learner was selected (empty coef()) yield all-zero flags
# and an empty estimate list instead of an error.
get.baselearner <- function(PRT_object) {

  model_list <- PRT_object$node_boost$model_list
  covariates <- PRT_object$covariates
  bl_names   <- c(covariates, "time", "x0")

  # 1/0 flag per entry of bl_names for one node model.
  selected_flags <- function(model) {
    coef_names <- unlist(lapply(coef(model), names), use.names = FALSE)
    vapply(bl_names,
           function(pattern) as.numeric(any(grepl(pattern, coef_names))),
           numeric(1), USE.NAMES = FALSE)
  }

  # Coefficients of one node model, renamed to the base-learner arguments.
  named_estimates <- function(model) {
    coef_list <- coef(model)
    if (length(coef_list) > 0 && !is.null(names(coef_list))) {
      names(coef_list) <- sapply(names(coef_list), function(nm) {
        gsub("[\\(\\)]", "", regmatches(nm, gregexpr("\\(.*?\\)", nm))[[1]])
      }, USE.NAMES = FALSE)
    }
    coef_list
  }

  node_labels <- paste0("Node ", seq_along(model_list))

  baselearner <- as.data.frame(matrix(unlist(lapply(model_list, selected_flags)),
                                      ncol = length(bl_names), byrow = TRUE))
  colnames(baselearner) <- bl_names
  rownames(baselearner) <- node_labels

  baselearner_est        <- lapply(model_list, named_estimates)
  names(baselearner_est) <- node_labels

  list(selected_bl  = baselearner,
       estimates_bl = baselearner_est)
}





# Analyze the split variables of the grown tree.
#
# The ctree was fit on `y1 + ... + y<n_time> ~ covariates`, so a split's
# `varid` counts the n_time response columns first; the covariate index is
# varid - n_time, with n_time taken from the time points stored in the
# training data (instead of a hard-coded 5).
#
# Returns a list with
#   splits_per_endnode  data frame (rows = covariates, columns = terminal node
#                       ids) with 1 if the covariate is split on along the
#                       path to that terminal node, else 0.
#   splits_list         data frame with one row per splitting node: node,
#                       split_variable and level (depth, root = 1); zero rows
#                       for a tree without splits.
analyze.splits <- function(PRT_object) {

  tree_obj    <- PRT_object$tree
  covariates  <- PRT_object$covariates
  node_list   <- tree_obj$node_list
  paths       <- tree_obj$paths
  parent_list <- tree_obj$parent_list

  n_resp <- length(unique(tree_obj$node_data[[1]]$time))

  # Covariate index split on in a node (integer(0) for terminal nodes).
  split_cov_index <- function(node) {
    varid <- node$split$varid
    if (is.null(varid)) integer(0) else varid - n_resp
  }

  # terminal nodes = nodes without kids
  is_leaf   <- vapply(node_list, function(nd) is.null(nd$kids), logical(1))
  end_nodes <- as.numeric(unlist(lapply(node_list[is_leaf], function(nd) nd$id)))

  # splits along the path to every terminal node
  split_matrix <- as.data.frame(matrix(0, ncol = length(end_nodes), nrow = length(covariates)))
  rownames(split_matrix) <- covariates
  for (j in seq_along(end_nodes)) {
    on_path <- node_list[paths[[end_nodes[j]]]]
    used    <- covariates[unlist(lapply(on_path, split_cov_index))]
    split_matrix[, j] <- (covariates %in% used) * 1
  }
  colnames(split_matrix) <- end_nodes

  # one row per splitting node, ordered by node
  split_idx   <- lapply(node_list, split_cov_index)
  has_split   <- lengths(split_idx) > 0
  split_nodes <- which(has_split)

  splits_return <- data.frame(
    node           = split_nodes,
    split_variable = covariates[unlist(split_idx[has_split])],
    level          = parent_list$level[match(split_nodes, parent_list$node)]
  )

  list(splits_per_endnode = split_matrix, splits_list = splits_return)
}



# Build the edge labels of the tree plot: one entry per node, entry 1 (the
# root, no incoming edge) is NA.
#
# Numeric splits are labelled "<= value" / "> value", rounded to 2 digits;
# if `data_non_centered` is given, the mean of the raw covariate is added to
# the split point (undoing centering). Factor splits (and rules mentioning
# "RANDOM") list the levels in the rule, inserting line breaks when a level
# name is longer than 20 characters.
#
# Arguments:
#   PRT                the `tree` element of a PRT object (list with `tree`,
#                      `node_data`, ...).
#   data_non_centered  optional uncentered data to recover split points.
#   covariates         covariate names.
#   data               data used to decide whether a split variable is a
#                      factor; defaults to the training data stored in the
#                      tree (previously an undefined global `data` was used).
extract_split_information <- function(PRT, data_non_centered = NULL, covariates,
                                      data = PRT$node_data[[1]]) {

  library(qdapRegex)
  library(stringi)

  # Label for a `%in%` rule: the quoted levels, comma separated; long level
  # names get a line break at a space near their middle.
  format_levels <- function(rule) {
    levels_txt <- str_split(rule, pattern = " %in% ")[[1]][2]
    lvls       <- rm_between(levels_txt, '"', '"', extract = TRUE)[[1]]
    if (length(lvls) == 1) {
      return(lvls)
    }
    n_char <- nchar(lvls)
    if (!any(n_char > 20)) {
      return(paste0(lvls, collapse = ", "))
    }
    for (idx in which(n_char > 20)) {
      set_break <- floor(length(strsplit(lvls[idx], " ")[[1]]) / 2)
      # names without spaces cannot be broken
      if (set_break >= 1) {
        where_break <- as.numeric(gregexpr(" ", lvls[idx])[[1]])[set_break]
        stri_sub(lvls[idx], where_break, where_break) <- "\n"
      }
    }
    paste0(lvls, collapse = ",\n")
  }

  edge_labels <- partykit:::.list.rules.party(PRT$tree, i = nodeids(PRT$tree))

  lab <- rep(NA_character_, length(edge_labels))
  for (i in seq_along(edge_labels)[-1]) {
    # the last condition of the rule is the split leading into node i
    parts <- str_split(edge_labels[i], pattern = " & ")[[1]]
    rule  <- parts[length(parts)]

    if (str_detect(rule, "RANDOM")) {
      lab[i] <- format_levels(rule)
      next
    }

    split_cov   <- covariates[str_detect(rule, pattern = covariates)]
    relation_in <- str_detect(rule, c("<=", ">"))

    if (any(relation_in) && !check_factor(split_cov, data)) {
      pattern     <- c("<= ", "> ")[relation_in]
      sp2         <- str_split(rule, pattern = pattern)[[1]]
      cov         <- sp2[1]
      split_value <- as.numeric(sp2[2])
      if (!is.null(data_non_centered)) {
        mean_cov    <- unlist(data_non_centered %>% dplyr::select(gsub(" ", "", cov))) %>% mean()
        split_value <- split_value + mean_cov
      }
      lab[i] <- paste0(pattern, as.character(round(split_value, 2)))
    } else {
      lab[i] <- format_levels(rule)
    }
  }

  lab
}



plot.PRT <- function(PRT_object, covariate_labels, pred, baselearner, covariate_labels_bl, work_data, plot.title,
                     label_size_bl = 4, left_bl  = -0.08, right_bl = 0.08, y_bl = -0.015, y_surv = -0.07,
                     y_bl_terminal   = -0.17, height_terminal = 0.5, width_terminal  = 0.7, height_bl = 0.5,
                     width_bl = 0.7, ylim_surv = c(25,100), data_non_centered = NULL) {
  
  options(warn = -1)
  
  # for tree
  PRT        <- PRT_object$tree
  nnodes     <- length(PRT$node_list)
  split_list <- list()
  for(n in 1:nnodes) {
    if(is.null(PRT$node_list[[n]]$kids)) {
      split_list[[n]] <- NULL
    } else {
      split_list[[n]] <- PRT$node_list[[n]]$split
    }
    
  }
  
  length(split_list) <- nnodes
  
  kids <- list() 
  for(n in 1:nnodes) {
    kids[[n]] <- c(na.omit(PRT$parent_list$node[PRT$parent_list$parent == n]))
    
    if(identical(kids[[n]], integer(0))) {
      kids[[n]] <- NULL
    }
  }
  
  length(kids) <- nnodes
  
  
  
  
  kid_list <- list()
  my_partynode_kids <- function(i) {
    
    kid_nodes <- kids[[i]]
    
    if(is.null(kid_nodes)) {
      return(NULL) 
    } else {
      
      #kid_list <- list()
      for(k in kid_nodes) {
        kid_list[[k]] <- partynode(k, split = split_list[[k]], kids = my_partynode_kids(k), info = NULL)
      }
      
      return_list <- kid_list[kid_nodes]
      
      return(return_list)
    }
    
    
    
    
  }
  
  
  
  kids <- my_partynode_kids(1)
  
  
  pn <- partynode(1, split = split_list[[1]], kids = kids)
  

  #colnames(dummy_data) <- c("id", paste0("V", 1:nnodes))
  k <- 1
  while (k <= length(cutoffs_pseudo)) {
    work_data <- work_data %>%
      add_column(xyz_dummy = rnorm(nrow(.)), .before = 1) 
    
    k <- k + 1
  }

  
  
  
  py <- party(pn, data)
  
  # first line node labels
  node_labels1 <- paste0("Node ", 1:nnodes)
  
  
  
  # number of observations as node labels 
  get_number_of_observations <- function(node) {
    if(PRT$parent_list$terminal[PRT$parent_list$node == node]) {
      n_obs <- sum(predict(PRT$tree, newdata = work_data, type = "node") == node)
    } else {
      n_obs <- PRT$node_list[[node]]$info$nobs
    }
    
  }
  
  n_obs <- unlist(sapply(1:nnodes, FUN = get_number_of_observations))
  
  node_labels1.1 <- paste0("(n = ", n_obs, ")")
  
  # second line node labels
  # default
  
  A <- sapply(1:nnodes, function(x) covariates[split_list[[x]]$varid - 5])
  
  
  A[sapply(A, FUN = function(x) identical(x, character(0)))] <- " "
  node_labels2 <- unlist(A)
  
  # customized (in function: covariate_labels)
  
  PRT$parent_list <- PRT$parent_list %>%
    mutate(terminal = ! node %in% parent)
  
  if(! is.null(covariate_labels)) {
    node_labels2 <- str_replace_all(node_labels2, covariate_labels)
  }
  

  work_data$end_node <- predict(PRT$tree, work_data, type = "node")
  
  
  node_labels3 <- c()
  for(i in 1:length(node_labels2)) {
    .tmp <- unlist(str_split(node_labels2[i], pattern = "\n"))
    if(length(.tmp) == 2) {
      node_labels2[i] <- .tmp[1]
      node_labels3[i] <- .tmp[2]
    } else if(length(.tmp) == 1 & PRT$parent_list$terminal[i]) {
      node_labels2[i]   <- node_labels1.1[i]
      node_labels3[i]   <- " "
      node_labels1.1[i] <- node_labels1[i]
      node_labels1[i]   <- " "
    } else if(length(.tmp) == 1 & ! PRT$parent_list$terminal[i]) {
      node_labels3[i] <- " "
    }
  }
  
  


  # edge labels 
  edge_labels <- extract_split_information(PRT, data_non_centered, covariates = PRT_object$covariates)
  
  
  # data for node plots (only terminal nodes)
  predictions <- pred %>%
    dplyr::select(pred, time, end_node) %>% 
    mutate(pat_id = rep(work_data$pid, each = 5), 
           x = rep(1:length(unique(time)), nrow(.)/length(unique(time)))) %>%
    rename("id" = "end_node")
  
  
  predictions_mean <- predictions %>%
    group_by(x, id) %>%
    summarise(mean_S = mean(pred))
  
  
  
  # data for baselearner plot 
  bl <- baselearner %>%
    add_column(id = 1:nnodes) %>%
    pivot_longer(cols = -id, names_to = "Baselearner", values_to = "Selected") %>%
    filter(Baselearner %in% c(covariates, "time")) 
  
  
  
  bl$Baselearner <- str_replace_all(bl$Baselearner, covariate_labels_bl)
  
  bl <- bl %>%
    mutate(Baselearner = factor(Baselearner, levels = c(sort(covariate_labels_bl), "time")))
  #bl$Baselearner <- factor(bl$Baselearner, levels = c(sort(covariate_labels_bl), "time"))
  

  
  

  
  # position of plots 
  # Classify a node into the slot used for its baselearner bar plot.
  #
  # node: node id, matched against PRT$parent_list$node.
  # Returns one of:
  #   "one"      - node 1 (checked first, before any lookup);
  #   "terminal" - the node's `terminal` flag in PRT$parent_list is TRUE;
  #   "left"     - node id equals its parent's id + 1;
  #   "right"    - any other parent/child id offset.
  # NOTE(review): the left/right rule assumes node ids are numbered so that a
  # left child is always parent + 1 - confirm against how PRT$tree is built.
  # PRT is a free variable taken from the enclosing function.
  position_bl_plot <- function(node) {
    if (node == 1) {
      return("one")
    }

    row_hit <- PRT$parent_list$node == node

    if (PRT$parent_list$terminal[row_hit]) {
      return("terminal")
    }

    offset_from_parent <- node - PRT$parent_list$parent[row_hit]
    ifelse(offset_from_parent == 1, "left", "right")
  }
  
  positions <- cbind.data.frame(node = 1:nnodes, position = unlist(lapply(1:nnodes, FUN = position_bl_plot)))

  
  library(ggparty)
  tree <- ggparty(py, terminal_space = 0.2) +
    geom_edge() +
    geom_edge_label(mapping = aes(label = edge_labels), parse = FALSE, size = 6) +
    #geom_node_splitvar() + 
    ggtitle(plot.title) + 
    geom_node_label(line_list = list(aes(label = node_labels1),
                                     aes(label = node_labels1.1),
                                     aes(label = node_labels2),
                                     aes(label = node_labels3)),
                    line_gpar = list(list(size = 12, col = "black",
                                          fontface = "bold", alignment = "center"),
                                     list(size = 12, col = "black",
                                          fontface = "bold", alignment = "center"),
                                     list(size = 12, col = "black",
                                          fontface = "italic", alignment = "center"),
                                     list(size = 12, col = "black",
                                          fontface = "italic", alignment = "center")), 
                    ids = "all", label.size = 1, 
                    label.padding = unit(1.5, "lines"), label.fill = "grey", label.r = unit(1.5, "lines")) +
    geom_node_plot(gglist = list(geom_point(data = predictions, aes(x = x, y = pred*100), color = "steelblue", 
                                            size = 2),
                                 theme_bw(),
                                 ylab("Survival [%]"),
                                 xlab("Time [years]"),
                                 ylim(ylim_surv),
                                 theme(text      = element_text(size = 10, color = "black"),
                                       axis.text = element_text(size = 10, color = "black"),
                                       #axis.text.x = element_blank(),
                                       #panel.grid.major.y = element_line(size = 2), 
                                       panel.grid.minor.y = element_blank(), 
                                       panel.grid.major.x = element_blank(),
                                       #plot.margin = margin(2, 2, 2, 2, "cm"), 
                                       legend.position = "none"),
                                 geom_line(data = predictions_mean, aes(x = x, y = mean_S*100), size = 1)), 
                   ids = "terminal", nudge_y = y_surv, nudge_x = -0.01, height = height_terminal, width = width_terminal) +
    geom_node_plot(gglist = list(geom_bar(data = bl %>% filter(id %in% (1:nnodes)[positions$position == "terminal"]), 
                                          aes(x = Baselearner, y = Selected), stat = "identity", fill = "steelblue"),
                                 theme_bw(),
                                 ylab("Selected"),
                                 #xlab("Baselearner"),
                                 scale_y_continuous(breaks = c(0,1), labels = c(0,1)), 
                                 #ylim(c(0, 1)),
                                 theme(text      = element_text(size = 10, color = "black"),
                                       axis.text = element_text(size = 10, color = "black"),
                                       axis.title.x = element_blank(),
                                       axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1, size = label_size_bl),
                                       #panel.grid.major.y = element_line(size = 2), 
                                       panel.grid.minor.y = element_blank(), 
                                       panel.grid.major.x = element_blank(),
                                       #plot.margin = margin(2, 2, 2, 2, "cm"), 
                                       legend.position = "none")),
                   ids = (1:nnodes)[positions$position == "terminal"], nudge_x = -0.001, nudge_y = y_bl_terminal, 
                   height = height_terminal, width = width_terminal) 
  
  
  
  if(max(PRT$parent_list$level) == 2) {
    warning("Return Terminal Node plots only")
    tree <- tree + theme(plot.title = element_text(size = 35, face = "bold", hjust = 0.5))
    return(tree)
  }
  
  
  if("left" %in% positions$position) {
    tree <- tree + geom_node_plot(gglist = list(geom_bar(data = bl %>% filter(id %in% (1:nnodes)[positions$position == "left"]), 
                                                         aes(x = Baselearner, y = Selected), stat = "identity", fill = "steelblue"),
                                                theme_bw(),
                                                ylab("Selected"),
                                                #xlab("Baselearner"),
                                                scale_y_continuous(breaks = c(0,1), labels = c(0,1)), 
                                                #ylim(c(0, 1)),
                                                theme(text      = element_text(size = 10, color = "black"),
                                                      axis.text = element_text(size = 10, color = "black"),
                                                      axis.title.x = element_blank(),
                                                      axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1, size = label_size_bl),
                                                      #panel.grid.major.y = element_line(size = 2), 
                                                      panel.grid.minor.y = element_blank(), 
                                                      panel.grid.major.x = element_blank(),
                                                      legend.position = "none")),
                                  ids = (1:nnodes)[positions$position == "left"], nudge_x = left_bl, nudge_y = y_bl,
                                  height = height_bl, width = width_bl)
  }
  
  if("right" %in% positions$position) {
    tree <- tree + geom_node_plot(gglist = list(geom_bar(data = bl %>% filter(id %in% (1:nnodes)[positions$position == "right"]), 
                                                         aes(x = Baselearner, y = Selected), stat = "identity", fill = "steelblue"),
                                                theme_bw(),
                                                ylab("Selected"),
                                                #xlab("Baselearner"),
                                                scale_y_continuous(breaks = c(0,1), labels = c(0,1)), 
                                                #ylim(c(0, 1)),
                                                theme(text      = element_text(size = 10, color = "black"),
                                                      axis.text = element_text(size = 10, color = "black"),
                                                      axis.title.x = element_blank(),
                                                      axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1, size = label_size_bl),
                                                      #panel.grid.major.y = element_line(size = 2), 
                                                      panel.grid.minor.y = element_blank(), 
                                                      panel.grid.major.x = element_blank(),
                                                      legend.position = "none")),
                                  ids = (1:nnodes)[positions$position == "right"], nudge_x = right_bl, nudge_y = y_bl,
                                  height = height_bl, width = width_bl)
  }
  if("one" %in% positions$position) {
    tree <- tree + geom_node_plot(gglist = list(geom_bar(data = bl %>% filter(id %in% (1:nnodes)[positions$position == "one"]), 
                                                         aes(x = Baselearner, y = Selected), stat = "identity", fill = "steelblue"),
                                                theme_bw(),
                                                ylab("Selected"),
                                                #xlab("Baselearner"),
                                                scale_y_continuous(breaks = c(0,1), labels = c(0,1)), 
                                                #ylim(c(0, 1)),
                                                theme(text      = element_text(size = 10, color = "black"),
                                                      axis.text = element_text(size = 10, color = "black"),
                                                      axis.title.x = element_blank(),
                                                      axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1, size = label_size_bl),
                                                      #panel.grid.major.y = element_line(size = 2), 
                                                      panel.grid.minor.y = element_blank(), 
                                                      panel.grid.major.x = element_blank(),
                                                      legend.position = "none")),
                                  ids = (1:nnodes)[positions$position == "one"], nudge_x = right_bl + 0.02, nudge_y = y_bl,
                                  height = height_bl, width = width_bl)
  }
  
  
  
  tree <- tree + theme(plot.title = element_text(size = 20, hjust = 0.5, face = "bold"))
  
  return(tree)
  
  
  
}



