🟦 Reconstructing Survival Curves from Bayesian PAMM Models in R (brms)
Brief:
Practical guide to reconstructing smooth, population-level survival curves with credible intervals from Bayesian Piecewise Exponential Additive Mixed Models (PAMMs) in R withbrms.
Uses vectorizeddata.tableapproach for computational efficiency with posterior samples.
📝 Summary
This note outlines a robust, computationally efficient workflow to derive smooth, population-level survival curves with Bayesian credible intervals from PAMMs fitted via brms. It covers posterior sampling, vectorized hazard computation, and proper uncertainty quantification at the population level.
🔑 Key Points
1️⃣ Bayesian Uncertainty Quantification
- Posterior Samples: Use
posterior_epred()to capture full Bayesian uncertainty - Population-Level Inference: Calculate credible intervals at population level, not averaged from individual uncertainties
- Computational Efficiency: Vectorized operations across all posterior samples simultaneously
2️⃣ Survival Curve Methodology
- Cumulative Hazard: \(H(t) = \sum_{ } \text{expected events}\)
- Survival Function: \(S(t) = \exp(-H(t))\)
- Population Averaging: Mean survival across subjects at each time point
3️⃣ Data Expansion & Grid Creation
- Fine Time Grid: Create dense intervals for smooth curves
- Subject Expansion: Cross all subjects with time grid
- Proper Offsets: Maintain log(interval) offsets for rate modeling
✅ "How-To" Steps
1️⃣ Extract Unique Subject-Level Covariates
# Extract unique combinations of subject characteristics
test <- data |>
dplyr::filter(id <= 100) |>
dplyr::select(-tstart, -tend, -interval, -offset, -ped_status) |>
dplyr::distinct()
2️⃣ Generate Sequential Intervals for Each Subject
# Define time parameters
max_time <- max(data$tend)
desired_interval <- 30
# Create time grid and interval dataframe
time_grid <- seq(0, max_time, by = desired_interval)
intervals_df <- data.frame(tstart = time_grid[-length(time_gr id)],
tend = time_grid[-1])
# Expand all subjects across time grid
expanded_data <- tidyr::crossing(test, intervals_df)
3️⃣ Rebuild Columns (All Intervals Positive)
# Reconstruct interval-specific variables
expanded_data <- expanded_data %>%
dplyr::mutate(
interval = tend - tstart, # Interval width
offset = log(interval), # Log offset for Poisson rate modeling
ped_status = 0 # No events for prediction
)
4️⃣ Get Posterior Samples for Bayesian Uncertainty
# Extract posterior samples of expected event counts
# This captures full Bayesian uncertainty in model parameters
epred_samples <- brms::posterior_epred(MV_PAMM_FOURTH, newdata = expanded_data)
n_samples <- nrow(epred_samples)
5️⃣ Vectorised Survival Calculation with data.table
library(data.table)
# Convert posterior samples to data.table format
# Transpose so each column represents one posterior sample
epred_dt <- as.data.table(t(epred_samples))
colnames(epred_dt) <- paste0("sample_", 1:n_samples)
# Combine expanded data with all posterior samples
expanded_dt <- cbind(as.data.table(expanded_data), epred_dt)
# Vectorized cumulative hazard calculation for ALL samples at once
# This replaces a slow for loop with fast vectorized operations
sample_cols <- paste0("sample_", 1:n_samples)
expanded_dt[order(tend), (paste0("cumhaz_", 1:n_samples)) := lapply(.SD, cumsum),
by = id, .SDcols = sample_cols]
# ^ For each subject (by = id), calculate cumulative sum of expected events
# ^ across ordered time points (order(tend)) for all posterior samples
# Vectorized survival probability calculation
survival_cols <- paste0("survival_", 1:n_samples)
cumhaz_cols <- paste0("cumhaz_", 1:n_samples)
expanded_dt[, (survival_cols) := lapply(.SD, function(x) exp(-x)), .SDcols = cumhaz_cols]
# ^ Apply exp(-cumhaz) transformation to convert cumulative hazard to survival
# Calculate population-level survival for each posterior sample
time_points <- sort(unique(expanded_dt$tend))
population_curves <- as.matrix(expanded_dt[, lapply(.SD, mean), by = tend, .SDcols = survival_cols][order(tend), ..survival_cols])
# ^ Average survival across all subjects at each time point
# ^ Results in one population curve per posterior sample
6️⃣ Summarize with Credible Intervals
# Calculate mean and credible intervals across posterior samples
summary_curve <- data.frame(
tend = time_points,
mean_survival = apply(population_curves, 1, mean), # Mean across samples
lower_ci = apply(population_curves, 1, quantile, 0.025), # 2.5th percentile
upper_ci = apply(population_curves, 1, quantile, 0.975) # 97.5th percentile
)
starting_point <- data.frame(
tend = 0,
mean_survival = 1,
lower_ci = 1,
upper_ci = 1
)
summary_curve <- rbind(starting_point, summary_curve)
# ^ Add starting point for survival at time 0
# ^ This ensures the curve starts at 100% survival
7️⃣ Create Publication-Ready Plot
library(ggplot2)
ggplot(summary_curve, aes(x = tend)) +
geom_ribbon(aes(ymin = lower_ci, ymax = upper_ci),
fill = "steelblue", alpha = 0.3) +
geom_line(aes(y = mean_survival),
color = "steelblue", size = 1.2) +
scale_y_continuous(limits = c(0, 1),
labels = scales::percent_format()) +
scale_x_continuous(expand = c(0, 0)) +
labs(
title = "Survival Curve with 95% Credible Intervals",
x = "Time",
y = "Survival Probability",
caption = "Shaded area represents 95% credible interval"
) +
theme_minimal() +
theme(
panel.grid.minor = element_blank(),
plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
axis.title = element_text(size = 12),
axis.text = element_text(size = 10)
)
💡 Key Learnings & Computational Notes
data.table Efficiency Gains
- Vectorization: Processes all 12,000+ posterior samples simultaneously
- Memory Efficiency: In-place operations with
:=assignment - Speed: 10-50x faster than equivalent
dplyrloops
Bayesian Methodology
- Proper Uncertainty: Population-level credible intervals capture true uncertainty
- Avoid Double-Averaging: Don't average within subjects first, then across subjects
- Full Posterior: Use all posterior samples for complete uncertainty quantification
Critical Technical Points
- Use
posterior_epred()for Bayesian uncertainty (notfitted()) - Maintain proper
offset = log(interval)for rate modeling - Order operations: cumulative hazard → survival → population average → credible intervals
🚀 Performance Benefits
This vectorized approach provides:
- Computational Speed: ~50x faster than nested loops
- Memory Efficiency: Processes large posterior samples without memory issues
- Scalability: Handles thousands of posterior samples and subjects efficiently
- Methodological Rigor: Proper Bayesian uncertainty quantification
How-To (Stratified)
Stratified Survival Curve Function:
stratified_survival_curves <- function(
data,
model,
covariate_name,
reference_id = 100,
max_time = NULL,
desired_interval = 30,
seed = 123) {
# 1. Extract base subject-level covariates (same as original)
base_test <- data %>%
dplyr::select(-tstart, -tend, -interval, -offset, -ped_status) |>
dplyr::distinct() |>
dplyr::slice_sample(n = min(reference_id, nrow(.)), seed = seed)
# 2. Get unique levels of stratification variable
unique_levels <- unique(data[[covariate_name]])
# 3. Create stratified test datasets - one row per stratum level
test <- purrr::map_dfr(unique_levels, function(level) {
modified_test <- base_test
modified_test[[covariate_name]] <- level
modified_test$stratum <- as.character(level)
return(modified_test)
})
# 4. Generate time grid (same as original)
if(is.null(max_time)) max_time <- max(data$tend)
time_grid <- seq(0, max_time, by = desired_interval)
intervals_df <- data.frame(tstart = time_grid[-length(time_grid)],
tend = time_grid[-1])
expanded_data <- tidyr::crossing(test, intervals_df)
# 5. Rebuild columns (same as original)
expanded_data <- expanded_data %>%
dplyr::mutate(
interval = tend - tstart,
offset = log(interval),
ped_status = 0
)
# 6. Get posterior samples (same as original)
epred_samples <- brms::posterior_epred(model, newdata = expanded_data)
n_samples <- nrow(epred_samples)
# 7. Vectorized calculation with stratification
library(data.table)
epred_dt <- as.data.table(t(epred_samples))
colnames(epred_dt) <- paste0("sample_", 1:n_samples)
expanded_dt <- cbind(as.data.table(expanded_data), epred_dt)
# Cumulative hazard BY STRATUM (note the grouping)
sample_cols <- paste0("sample_", 1:n_samples)
expanded_dt[order(tend), (paste0("cumhaz_", 1:n_samples)) := lapply(.SD, cumsum),
by = .(id, stratum), .SDcols = sample_cols]
# Survival calculation (same as original)
survival_cols <- paste0("survival_", 1:n_samples)
cumhaz_cols <- paste0("cumhaz_", 1:n_samples)
expanded_dt[, (survival_cols) := lapply(.SD, function(x) exp(-x)), .SDcols = cumhaz_cols]
# Population means BY STRATUM
time_points <- sort(unique(expanded_dt$tend))
# Calculate stratified results
results_list <- list()
for(stratum_level in unique_levels) {
stratum_data <- expanded_dt[stratum == stratum_level]
pop_curves <- as.matrix(stratum_data[, lapply(.SD, mean), by = tend,
.SDcols = survival_cols][order(tend), ..survival_cols])
results_list[[as.character(stratum_level)]] <- data.frame(
tend = time_points,
stratum = stratum_level,
mean_survival = apply(pop_curves, 1, mean),
lower_ci = apply(pop_curves, 1, quantile, 0.025),
upper_ci = apply(pop_curves, 1, quantile, 0.975)
)
}
# Combine all strata
return(do.call(rbind, results_list))
}
Usage:
# Generate stratified curves for a factor variable
stratified_results <- stratified_survival_curves(
data = PAMM_DATA_SECOND,
model = MV_PAMM_FOURTH,
covariate_name = "your_factor_variable"
)
# Plot stratified curves
ggplot(stratified_results, aes(x = tend, color = factor(stratum), fill = factor(stratum))) +
geom_ribbon(aes(ymin = lower_ci, ymax = upper_ci), alpha = 0.2) +
geom_line(aes(y = mean_survival), size = 1.2) +
scale_y_continuous(limits = c(0, 1), labels = scales::percent_format()) +
labs(title = "Stratified Survival Curves", x = "Time", y = "Survival Probability") +
theme_minimal()