This is an illustration of the Pseudo-value regression trees (PRT) method applied to data from the Rotterdam tumor bank (Royston and Altman, 2013), which is provided in the survival package in R. The outcome of interest in this illustrative example is overall survival. Information on the R version this code was written with is given below.
# version info
version
## _
## platform x86_64-w64-mingw32
## arch x86_64
## os mingw32
## system x86_64, mingw32
## status
## major 4
## minor 1.0
## year 2021
## month 05
## day 18
## svn rev 80317
## language R
## version.string R version 4.1.0 (2021-05-18)
## nickname Camp Pontanezen
The file Functions.R contains the implementation of PRT as well as functions for analyzing the output of the PRT method.
rm(list = ls())
source("R/Functions.R")
library(batchtools)
library(tidyverse)
library(magrittr)
library(pseudo)
library(mboost)
library(partykit)
library(stringr)
library(VGAM)
library(Matrix)
library(ggparty)
library(cowplot)
library(ggtext)
library(grid)
library(gridExtra)
library(pander)
library(survival)
library(qdapRegex)
# load example data (this also makes the rotterdam data set available)
data(cancer, package = "survival")
data <- rotterdam %>%
  filter(nodes > 0) %>% # keep patients with nodes > 0 (as in Royston and Altman, 2013)
  mutate(meno = as.factor(meno),
         grade = as.factor(grade),
         hormon = as.factor(hormon),
         chemo = as.factor(chemo))
The data frame data contains \(n = 1546\) female patients with primary breast cancer. It includes the covariates age at surgery (age, measured in years), menopausal status (meno, two categories, pre- (=0) / post- (=1) menopausal), tumor size (size, three categories, <=20, 20-50, >50), differentiation grade (grade, two categories, 2/3), number of positive lymph nodes (nodes), progesterone receptors (pgr, measured in \(fmol/l\)), estrogen receptors (er, measured in \(fmol/l\)), hormonal treatment (hormon, two categories, yes (1) / no (0)) and chemotherapy (chemo, two categories, yes (1) / no (0)) (see Royston and Altman (2013) for details on the data). The time (in days) and event status (0 = censored, 1 = event) for overall survival are contained in the columns dtime and death, respectively.
We include all of the covariates listed above in the analysis.
# define covariates for tree building and node-wise gradient boosting
covariates <- c("age", "meno", "size", "grade", "nodes",
                "pgr", "er", "hormon", "chemo")
We chose five time points (\(t_1, \ldots , t_5\)) corresponding to 1 to 5 years of follow-up. Note that the maximal cutoff value should not exceed the maximal observation time (the maximal value in the dtime column).
# check distribution of observed times in the data
summary(data$dtime)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 45 1175 2166 2300 3275 7027
# define time points
cutoffs_pseudo <- c(1:5)*365.25
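The note above on the maximal time point can be turned into an explicit check before model building. The helper check_cutoffs below is hypothetical (it is not part of Functions.R); it is a minimal base R sketch that stops if any time point exceeds the largest observed time. For the summary shown above, the largest time point, 5 * 365.25 = 1826.25 days, is well below the maximum of 7027 days.

```r
# Hypothetical helper (not part of Functions.R): stop early if any
# pseudo-value time point lies beyond the largest observed time.
check_cutoffs <- function(cutoffs, time) {
  max_time <- max(time, na.rm = TRUE)
  too_late <- cutoffs[cutoffs > max_time]
  if (length(too_late) > 0) {
    stop("time points exceed the maximal observation time (", max_time,
         "): ", paste(too_late, collapse = ", "))
  }
  invisible(TRUE)
}

# usage with the objects defined above:
# check_cutoffs(cutoffs_pseudo, data$dtime)
```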
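The main function PRT performs the following steps, as also reported in its console output: (1) calculation of the pseudo-values, (2) optimization of the number of boosting iterations (mstop) via cross-validation and (3) fitting of PRT with the optimal mstop on all data.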
# model building
prt <- PRT(data = data, covariates = covariates, cutoffs_pseudo = cutoffs_pseudo,
           time_col = "dtime", status_col = "death", id_col = "pid",
           D = 3, fold = 5, center = FALSE)
## [1] "Calculate pseudo-values..."
## [1] "Optimize mstop via cross-validation..."
## [1] "Fit PRT with optimal mstop on all data..."
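The first step, the calculation of pseudo-values, can be illustrated with the jackknife construction for a survival probability: for time point \(t\) and observation \(i\), \(\hat\theta_i(t) = n\hat S(t) - (n-1)\hat S^{(-i)}(t)\), where \(\hat S\) is the Kaplan-Meier estimate based on all \(n\) observations and \(\hat S^{(-i)}\) the estimate with observation \(i\) left out. The base R sketch below computes these pseudo-values directly; the function names km_at, pseudo_surv and pseudo_matrix are hypothetical. The pseudo package loaded above provides pseudosurv for this purpose, but whether Functions.R uses it in exactly this way is not shown here.

```r
# Kaplan-Meier estimate of S(t) at a single time point t (base R only)
km_at <- function(time, status, t) {
  event_times <- sort(unique(time[status == 1 & time <= t]))
  s <- 1
  for (u in event_times) {
    n_risk <- sum(time >= u)             # subjects still at risk at u
    d <- sum(time == u & status == 1)    # events at u
    s <- s * (1 - d / n_risk)
  }
  s
}

# Jackknife pseudo-values for S(t): n * S(t) - (n - 1) * S^(-i)(t)
pseudo_surv <- function(time, status, t) {
  n <- length(time)
  s_full <- km_at(time, status, t)
  vapply(seq_len(n), function(i) {
    n * s_full - (n - 1) * km_at(time[-i], status[-i], t)
  }, numeric(1))
}

# One column of pseudo-values per time point (wide format, n x k matrix)
pseudo_matrix <- function(time, status, cutoffs) {
  sapply(cutoffs, function(t) pseudo_surv(time, status, t))
}
```

Without censoring, the pseudo-value reduces to the indicator that observation \(i\) survives beyond \(t\), which gives a convenient check of such an implementation.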
The function PRT takes the following arguments:
The function PRT returns an object containing the following elements:
The function predict.PRT returns the estimated survival probabilities (column pred) in long format, i.e., one row per observation and time point, attached to the original data. Additionally, the terminal node each observation falls into is attached (column end_node).
# predictions
pred_PRT <- predict.PRT(prt)
pred_PRT %>%
  filter(pid %in% c(10, 12, 19)) %>% # show parts of the results for end nodes 4, 7 and 8
  arrange(end_node) %>%
  dplyr::select(c("nodes", "size", "meno", "er", "pgr", "time", "pred", "end_node"))
## # A tibble: 15 x 8
## nodes size meno er pgr time pred end_node
## <int> <fct> <fct> <int> <int> <dbl> <dbl> <dbl>
## 1 1 <=20 0 67 245 365 0.997 4
## 2 1 <=20 0 67 245 730 0.978 4
## 3 1 <=20 0 67 245 1095 0.930 4
## 4 1 <=20 0 67 245 1461 0.882 4
## 5 1 <=20 0 67 245 1826 0.829 4
## 6 1 20-50 0 0 14 365 0.983 7
## 7 1 20-50 0 0 14 730 0.928 7
## 8 1 20-50 0 0 14 1095 0.838 7
## 9 1 20-50 0 0 14 1461 0.770 7
## 10 1 20-50 0 0 14 1826 0.702 7
## 11 5 20-50 0 976 316 365 0.961 8
## 12 5 20-50 0 976 316 730 0.876 8
## 13 5 20-50 0 976 316 1095 0.764 8
## 14 5 20-50 0 976 316 1461 0.684 8
## 15 5 20-50 0 976 316 1826 0.610 8
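Since the predictions come in long format, simple plausibility checks are straightforward. For instance, the estimated survival probabilities of a single observation should not increase over time, as is the case for the three patients shown above. The hypothetical helper check_monotone below (not part of Functions.R) returns the ids whose predicted curve increases somewhere; whether PRT enforces monotonicity by construction is not stated here. The sketch assumes the column names pid, time and pred used in pred_PRT.

```r
# Hypothetical sanity check (not part of Functions.R): within each
# observation, predicted survival should not increase over time.
check_monotone <- function(pred_long, id_col = "pid",
                           time_col = "time", pred_col = "pred") {
  by_id <- split(pred_long, pred_long[[id_col]])
  ok <- vapply(by_id, function(d) {
    p <- d[[pred_col]][order(d[[time_col]])]  # predictions sorted by time
    all(diff(p) <= 0)
  }, logical(1))
  names(ok)[!ok]  # ids whose predicted curve increases somewhere
}

# usage with the predictions above; character(0) means no violations:
# check_monotone(pred_PRT)
```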
The function get.baselearner, applied to the object returned by the PRT function, returns a list containing (i) a data frame indicating which base-learner was selected (1 = yes, 0 = no) in which node-wise boosting model (stored in selected_bl) and (ii) a list containing the estimates for the base-learners in each node (stored in estimates_bl).
# extract selected base-learner
baselearner <- get.baselearner(prt)
baselearner$selected_bl # data frame indicating which base-learner was selected
## age meno size grade nodes pgr er hormon chemo time x0
## Node 1 0 0 1 1 1 1 1 1 1 0 0
## Node 2 0 0 1 1 0 0 1 1 1 0 0
## Node 3 0 1 0 1 0 0 1 0 1 0 0
## Node 4 0 0 0 1 0 0 1 0 0 0 0
## Node 5 0 0 0 0 0 0 1 0 0 0 0
## Node 6 0 0 0 1 1 0 1 1 1 0 0
## Node 7 0 1 0 0 0 0 1 1 1 0 0
## Node 8 0 0 0 1 0 0 1 0 1 0 0
## Node 9 0 0 1 1 0 1 1 0 0 0 0
## Node 10 0 0 1 1 0 0 1 1 0 0 0
## Node 11 0 0 0 1 0 0 1 0 0 0 0
## Node 12 0 0 1 0 0 0 1 0 1 0 0
## Node 13 0 0 1 0 0 0 1 0 0 0 0
## Node 14 0 0 0 0 0 0 1 1 0 0 0
## Node 15 0 1 0 0 0 0 1 0 0 0 0
#baselearner$estimates_bl # list containing the estimates for the base-learner in each node
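The selected_bl table records which base-learners entered each node-wise boosting model. The mechanism behind such a selection can be sketched with component-wise L2 boosting: in each iteration, every covariate is fitted separately to the current residuals (the negative gradient of the squared error loss), only the best-fitting base-learner is updated by a step of length nu, and covariates that are never chosen within mstop iterations remain unselected. The base R sketch below uses simple linear base-learners on numeric covariates only; it is a generic illustration, not the mboost configuration used inside Functions.R, and the function name componentwise_boost is hypothetical.

```r
# Component-wise L2 boosting with simple linear base-learners (base R).
componentwise_boost <- function(X, y, mstop = 100, nu = 0.1) {
  X <- scale(as.matrix(X), center = TRUE, scale = FALSE)  # centred covariates
  offset <- mean(y)                        # start from the mean response
  coefs <- setNames(numeric(ncol(X)), colnames(X))
  counts <- setNames(integer(ncol(X)), colnames(X))
  fit <- rep(offset, length(y))
  for (m in seq_len(mstop)) {
    u <- y - fit                           # negative gradient of the L2 loss
    beta <- colSums(X * u) / colSums(X^2)  # least-squares slope per covariate
    rss <- colSums((u - sweep(X, 2, beta, "*"))^2)
    j <- which.min(rss)                    # best-fitting base-learner
    coefs[j] <- coefs[j] + nu * beta[j]    # damped update of that one only
    counts[j] <- counts[j] + 1L
    fit <- fit + nu * beta[j] * X[, j]
  }
  list(offset = offset, coefs = coefs, counts = counts,
       selected = setNames(as.integer(counts > 0), colnames(X)))
}
```

With a strong signal in one covariate and pure noise in the others, the informative covariate dominates the early iterations; later iterations may pick up noise covariates, which is one reason to tune mstop by cross-validation rather than setting it large.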
The function analyze.splits, applied to the object returned by the PRT function, returns a list containing (i) a data frame indicating which split variables were selected (1 = yes, 0 = no) on the path to each end node (one column per end node in splits_per_endnode) and (ii) a data frame containing the split variable of each inner node of the regression tree together with the level of that node in the tree (stored in splits_list).
# extract split variables
splits <- analyze.splits(prt)
splits$splits_per_endnode # data frame indicating which split variables were selected in each path
## 4 5 7 8 11 12 14 15
## age 0 0 0 0 0 0 0 0
## meno 1 1 0 0 0 0 0 0
## size 1 1 1 1 0 0 1 1
## grade 0 0 0 0 0 0 0 0
## nodes 1 1 1 1 1 1 1 1
## pgr 0 0 0 0 1 1 1 1
## er 0 0 0 0 1 1 0 0
## hormon 0 0 0 0 0 0 0 0
## chemo 0 0 0 0 0 0 0 0
splits$splits_list # data frame containing the split variables and levels for each node in the regression tree
## node split_variable level
## 1 1 nodes 1
## 2 2 size 2
## 3 3 meno 3
## 4 6 nodes 3
## 5 9 pgr 2
## 6 10 er 3
## 7 13 size 3
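The two outputs are linked through the tree structure: the column of an end node in splits_per_endnode marks the split variables of the inner nodes on its path from the root, and the level column of splits_list appears to equal the depth of the node. The base R sketch below reproduces both from splits_list, assuming the depth-first node numbering seen in the output (inner nodes 1, 2, 3, 6, 9, 10, 13 and end nodes 4, 5, 7, 8, 11, 12, 14, 15). The parent vector is reconstructed from that numbering for illustration, and the helper ancestors is hypothetical.

```r
# Parent of each non-root node, reconstructed from the depth-first numbering
parent <- c(`2` = 1, `9` = 1, `3` = 2, `6` = 2, `4` = 3, `5` = 3,
            `7` = 6, `8` = 6, `10` = 9, `13` = 9, `11` = 10, `12` = 10,
            `14` = 13, `15` = 13)
# Split variable of each inner node (copied from splits$splits_list)
split_var <- c(`1` = "nodes", `2` = "size", `3` = "meno", `6` = "nodes",
               `9` = "pgr", `10` = "er", `13` = "size")

# Inner nodes on the path from `node` up to the root (excluding `node`)
ancestors <- function(node) {
  path <- integer(0)
  while (as.character(node) %in% names(parent)) {
    node <- parent[[as.character(node)]]
    path <- c(path, node)
  }
  path
}

covariates <- c("age", "meno", "size", "grade", "nodes",
                "pgr", "er", "hormon", "chemo")
end_nodes <- c(4, 5, 7, 8, 11, 12, 14, 15)

# 0/1 matrix: was a covariate used for a split on the path to an end node?
path_splits <- sapply(end_nodes, function(e) {
  used <- split_var[as.character(ancestors(e))]
  as.integer(covariates %in% used)
})
dimnames(path_splits) <- list(covariates, end_nodes)

# Depth of each inner node (root = 1), cf. the level column of splits_list
depth <- sapply(as.numeric(names(split_var)),
                function(k) length(ancestors(k)) + 1)
```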
The function plot.PRT plots the tree structure with selected base-learners at each node and estimated survival probabilities in the terminal nodes.
# define covariate labels in the plot
lbls <- covariates
names(lbls) <- covariates
# plot results
plot.PRT(PRT_object = prt, covariate_labels = lbls, pred = pred_PRT,
         baselearner = baselearner$selected_bl,
         covariate_labels_bl = lbls, work_data = data,
         plot.title = "Illustration of PRT", label_size_bl = 8,
         left_bl = -0.08, right_bl = 0.08,
         y_bl = -0.015, y_surv = -0.07, y_bl_terminal = -0.17,
         height_terminal = 0.5, width_terminal = 0.7, height_bl = 0.5,
         width_bl = 0.7, ylim_surv = c(40,100))
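The labels do not have to be the raw column names. Assuming plot.PRT looks up labels by name, as the construction names(lbls) <- covariates suggests, a named character vector with more descriptive labels could be passed instead. The name lbls_long and the label texts below are illustrative choices.

```r
covariates <- c("age", "meno", "size", "grade", "nodes",
                "pgr", "er", "hormon", "chemo")
# Named vector: names are the column names, values are the printed labels
lbls_long <- c(age = "Age (years)", meno = "Menopausal status",
               size = "Tumor size", grade = "Grade",
               nodes = "Positive lymph nodes", pgr = "PgR (fmol/l)",
               er = "ER (fmol/l)", hormon = "Hormonal therapy",
               chemo = "Chemotherapy")
# Put the labels in the same order as the covariates (as in lbls)
lbls_long <- lbls_long[covariates]

# possible use (assumption, see above):
# plot.PRT(..., covariate_labels = lbls_long, covariate_labels_bl = lbls_long, ...)
```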
The arguments of the plot.PRT function are given in the following table.
Institute of Medical Biometry, Informatics and Epidemiology, University of Bonn, schenk@imbie.uni-bonn.de