# NOTE: the global environment is deliberately NOT cleared here.
# `rm(list = ls())` deletes whatever objects the user currently has in the
# workspace, yet it does not detach packages, reset options or restore the
# working directory, so it does not provide a clean state. Run this script in
# a fresh R session instead.

# source required functions and load necessary packages
source("R/Functions.R")

library(batchtools)
library(tidyverse)
library(magrittr)
library(pseudo)
library(mboost)
library(partykit)
library(stringr)
library(VGAM)
library(Matrix)
library(ggparty)
library(cowplot)
library(ggtext)
library(grid)
library(gridExtra)
library(pander)
library(survival)
library(qdapRegex) 



# load the example data sets of the survival package (`rotterdam` is used below)
data(cancer, package = "survival")

# restrict to patients with nodes > 0 (as in Royston and Altman, 2013) and
# store the categorical covariates as factors
data <- rotterdam %>%
  filter(nodes > 0) %>%
  mutate(across(c(meno, grade, hormon, chemo), as.factor))


# covariates used for tree building and for node-wise gradient boosting
covariates <- c(
  "age", "meno", "size",
  "grade", "nodes", "pgr",
  "er", "hormon", "chemo"
)



# inspect the distribution of the observed times before choosing time points
summary(data[["dtime"]])

# time points: multiples 1 to 5 of 365.25 (i.e. one to five years when the
# time scale is days)
cutoffs_pseudo <- 365.25 * (1:5)


# model building: fit the PRT model on the prepared data
# NOTE(review): `fold = 5` looks like a cross-validation setting; if PRT()
# assigns folds at random, a set.seed() call before this line would be needed
# for reproducible results -- confirm in R/Functions.R.
prt <- PRT(
  data           = data,
  covariates     = covariates,
  cutoffs_pseudo = cutoffs_pseudo,
  time_col       = "dtime",
  status_col     = "death",
  id_col         = "pid",
  D              = 3,
  fold           = 5,
  center         = FALSE
)

# predictions from the fitted model
pred_PRT <- predict.PRT(prt)

# show part of the results (patients 10, 12 and 19; end nodes 4, 7 and 8),
# ordered by end node and reduced to the columns of interest
pred_PRT %>%
  filter(pid %in% c(10, 12, 19)) %>%
  arrange(end_node) %>%
  dplyr::select(
    c(
      "nodes",
      "size",
      "meno",
      "er",
      "pgr",
      "time",
      "pred",
      "end_node"
    )
  )


# extract the base-learners selected by the fitted model
baselearner <- get.baselearner(prt)

# data frame indicating which base-learner was selected
baselearner[["selected_bl"]]
# list containing the estimates for the base-learner in each node:
# baselearner[["estimates_bl"]]


# extract the split variables of the fitted tree
splits <- analyze.splits(prt)

# data frame indicating which split variables were selected in each path
splits[["splits_per_endnode"]]

# data frame containing the split variables and levels for each node in the
# regression tree
splits[["splits_list"]]


# covariate labels for the plot: each covariate is labelled with its own name
lbls <- setNames(covariates, covariates)


# plot results: tree, predictions and selected base-learners in one figure
plot.PRT(
  PRT_object          = prt,
  covariate_labels    = lbls,
  pred                = pred_PRT,
  baselearner         = baselearner$selected_bl,
  covariate_labels_bl = lbls,
  work_data           = data,
  plot.title          = "Illustration of PRT",
  label_size_bl       = 8,
  left_bl             = -0.08,
  right_bl            = 0.08,
  y_bl                = -0.015,
  y_surv              = -0.07,
  y_bl_terminal       = -0.17,
  height_terminal     = 0.5,
  width_terminal      = 0.7,
  height_bl           = 0.5,
  width_bl            = 0.7,
  ylim_surv           = c(40, 100)
)
