🟦 Bayesian Cox Survival Curves with brms: Manual M/I-Spline Reconstruction and Full Posterior Inference
Brief:
End-to-end guide for fitting flexible Bayesian Cox proportional hazards models with brms and manually reconstructing population-level survival curves with credible intervals, using all posterior draws and explicit spline mathematics.
📝 Summary
This workflow details, step by step, how to extract population survival curves (with Bayesian credible intervals) from a Bayesian Cox model, explicitly handling the M-spline/I-spline structure used by brms for the flexible baseline hazard.
🔑 Key Points
1️⃣ Bayesian Cox Model with M-spline Baseline
- M-splines (5 by default): Flexible, nonparametric way to model the baseline hazard in brms
- Population-level curve: Averaging over the sample’s covariate distribution (not just a hypothetical "average" patient)
- All posterior draws: Full propagation of model uncertainty
2️⃣ Survival & Hazard Formulae
- Individual hazard:
$$ h(t|x) = h_0(t) \times \exp(\beta_0 + \beta_1\, \text{age}) $$
- Survival:
$$ S(t|x) = \exp(-H_0(t) \times \exp(\beta_0 + \beta_1\, \text{age})) $$
Where \(H_0(t)\) is built from I-splines (the integrated M-spline basis).
- Population survival:
Mean across all real patients, for each draw and time point.
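Written out, the quantities computed later in the workflow look as follows. The notation is introduced here for illustration only: \(\gamma_k\) stands for the spline weights (the `sbhaz` parameters), \(M_k\) and \(I_k\) for the M- and I-spline basis functions, \(L\) for the lower boundary knot, \(d\) for a posterior draw, and \(n\) for the number of patients.

$$ H_0(t) = \sum_{k=1}^{K} \gamma_k\, I_k(t), \qquad I_k(t) = \int_{L}^{t} M_k(u)\, du $$

$$ \bar{S}^{(d)}(t) = \frac{1}{n} \sum_{i=1}^{n} \exp\left(-H_0^{(d)}(t) \times \exp\left(\beta_0^{(d)} + \beta_1^{(d)}\, \text{age}_i\right)\right) $$

The posterior mean and credible interval at each time \(t\) are then summaries of \(\bar{S}^{(d)}(t)\) across draws \(d\).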
✅ "How-To" Steps
1️⃣ Fit the Bayesian Cox Model
# ============================================================================
# Bayesian Cox Proportional Hazards Survival Analysis with Manual Curve Extraction
# ============================================================================
library(survival)
library(brms)
library(splines2)
library(dplyr)
# Load VA lung cancer dataset (137 patients)
veteran <- survival::veteran
# Fit Cox proportional hazards model using brms
# - Uses M-splines (5 basis functions by default) for flexible baseline hazard
# - Avoids parametric assumptions about hazard shape over time
# - Age as continuous covariate with proportional hazards assumption
cox_vet <- brms::brm(
  formula = time | cens(1 - status) ~ age, # Survival outcome with censoring
  family = brms::cox(),                    # Cox proportional hazards
  data = veteran,
  cores = 10,
  chains = 4,
  iter = 3000,
  warmup = 1000,
  seed = 123
)
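The `cens(1 - status)` term flips the event indicator. In `veteran`, `status` is 1 for an observed death and 0 for a censored time, while brms's `cens()` expects 1 for right-censored and 0 for uncensored. A minimal sketch on a made-up status vector (the values are illustrative, not taken from the dataset):

```r
# Toy status vector in the veteran coding: 1 = death observed, 0 = censored
status <- c(1, 0, 1, 1, 0)

# brms's cens() uses the opposite coding: 0 = event observed, 1 = right censored
censored <- 1 - status

# Cross-tabulate to confirm every death maps to 0 and every censored time to 1
print(table(status = status, censored = censored))
```

Passing `status` directly would treat every death as a censored observation.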
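2️⃣ Extract Posterior Samples
# Assumed form of `cox_draws` (used in later steps): a plain data frame with
# one row per post-warmup draw and one column per parameter
# (b_Intercept, b_age, sbhaz[1], ..., sbhaz[5])
cox_draws <- as.data.frame(cox_vet)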
3️⃣ Set Up Population and Time Grid for Predictions
# Use entire patient population for proper age distribution weighting
# This creates population-averaged survival curves, not individual predictions
all_patients <- veteran |>
  dplyr::select(age)
n_patients <- nrow(all_patients) # 137 patients
# Create regular time grid for survival curve evaluation
max_time <- max(veteran$time)
pamm_time_grid <- seq(0, max_time, by = 1)
n_time_points <- length(pamm_time_grid)
cat("Predicting for", n_patients, "patients over", n_time_points, "time points\n")
4️⃣ Extract Spline Parameters and Posterior Coefficients
# Extract M-spline parameters that brms used for baseline hazard
# To manually recreate survival curves, we need the exact same spline basis
# that brms used internally when fitting the model
# Internal knots: points where spline pieces connect (usually at data quantiles)
knots <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "knots")
# Boundary knots: start/end points of spline domain (near the min/max observed survival times)
boundary_knots <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "Boundary.knots")
# Degree: polynomial degree of each spline piece (typically 3 for smooth curves)
degree <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "degree")
# Why extract these? The baseline hazard h₀(t) was modeled using M-splines with
# these exact parameters. To calculate survival curves manually, we need to
# recreate the same I-spline basis (integral of M-splines) using identical settings.
# Extract posterior samples of model parameters
age_coef <- cox_draws$b_age # Age effect coefficient
intercept_coef <- cox_draws$b_Intercept # Baseline hazard intercept
sbhaz_cols <- paste0("sbhaz[", 1:5, "]") # M-spline basis coefficients
sbhaz_coefs <- cox_draws[, sbhaz_cols]
n_draws <- nrow(cox_draws)
survival_matrix <- matrix(NA, nrow = n_draws, ncol = n_time_points)
🧠 Mathematical/Model Commentary
KEY INSIGHT:
brms Cox models work as follows:
1. Baseline hazard \(h_0(t)\) is modeled using M-splines (hazard rate).
2. Cumulative baseline hazard \(H_0(t)\) uses I-splines (integral of M-splines).
3. Individual hazard: \(h(t|x) = h_0(t) \exp(\beta_0 + \beta_1\, \text{age})\)
4. Survival function: \(S(t|x) = \exp(-H_0(t) \exp(\beta_0 + \beta_1\, \text{age}))\)
Breakdown & Intuition:
- \(h_0(t)\) = baseline hazard (from M-splines)
- \(\beta_0\) = intercept, \(\beta_1\) = age log-hazard ratio per year
- \(\exp(\beta_0 + \beta_1\cdot \text{age})\) = patient-specific hazard multiplier
- \(H_0(t)\) = cumulative hazard up to t (I-spline)
- More cumulative hazard (or higher risk/multiplier) → lower survival probability
- Survival is always \(S(t) = \exp(-H(t))\), ensuring values in [0,1]
Some numeric examples:
- \(H = 0 \rightarrow S = 1.0\) (100%)
- \(H = 0.7 \rightarrow S \approx 0.5\) (50%)
- \(H = 2.3 \rightarrow S \approx 0.1\) (10%)
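These values can be reproduced in a few lines of base R. This is only a quick check of \(S = \exp(-H)\); the vector names `H` and `S` are chosen here for illustration.

```r
# Cumulative hazard values from the numeric examples (H is an illustrative name)
H <- c(0, 0.7, 2.3)

# Survival probability implied by each cumulative hazard
S <- exp(-H)

# Rounded to two decimals for comparison with the percentages in the text
print(round(S, 2))
```

Going the other way, `-log(S)` recovers the cumulative hazard from a survival probability.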
About M-splines:
- Non-negative, locally supported spline basis for \(h_0(t)\)
- Default in brms: 5 basis functions with quantile-based knots
- A weighted sum of the basis functions can produce U-shaped, increasing, decreasing, or multi-modal hazards
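A minimal sketch of the M-spline/I-spline relationship on simulated times rather than the fitted model. It assumes `splines2` is installed. `df = 5` and `intercept = TRUE` are chosen to mimic the five-basis default described in the text, and `toy_times` and the weights `w` are made up for illustration.

```r
library(splines2)

set.seed(1)
toy_times <- sort(runif(200, min = 1, max = 100))  # simulated times (illustrative)

# M-spline basis: non-negative pieces that build the baseline hazard
m_basis <- mSpline(toy_times, df = 5, degree = 3, intercept = TRUE)

# Rebuild the matching I-spline basis from the attributes stored on the M-spline
i_basis <- iSpline(
  toy_times,
  knots = attr(m_basis, "knots"),
  Boundary.knots = attr(m_basis, "Boundary.knots"),
  degree = attr(m_basis, "degree"),
  intercept = TRUE
)

w  <- c(0.1, 0.3, 0.2, 0.3, 0.1)        # arbitrary non-negative weights summing to 1
h0 <- as.numeric(m_basis %*% w)         # baseline hazard: weighted sum of M-splines
H0 <- as.numeric(i_basis %*% w)         # cumulative hazard: weighted sum of I-splines

print(dim(i_basis))                     # one column per spline weight
```

Because the weights are non-negative, `h0` stays non-negative and `H0` is non-decreasing in time, which is what makes the resulting survival curve monotone.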
5️⃣ Manual Survival Curve Calculation Using I-splines
# Create I-spline basis for cumulative hazard at our prediction times
ibasis_pred <- iSpline(
  pamm_time_grid,
  knots = knots,
  Boundary.knots = boundary_knots,
  degree = degree
)
# Loop through each posterior draw to create survival curve distribution
for (i in seq_len(n_draws)) {
  if (i %% 1000 == 0) cat("Processing draw", i, "of", n_draws, "\n")
  # Extract parameters for this posterior draw
  beta_age <- age_coef[i]                     # β₁ - age coefficient
  beta_intercept <- intercept_coef[i]         # β₀ - intercept coefficient
  sbhaz_draw <- as.numeric(sbhaz_coefs[i, ])  # M-spline coefficients
  # Calculate cumulative baseline hazard H₀(t) at prediction times
  cumulative_hazard_pred <- as.numeric(ibasis_pred %*% sbhaz_draw)
  # (%*% means matrix multiplication)
  # Calculate survival for each patient at each time point
  patient_survivals <- matrix(NA, nrow = n_patients, ncol = n_time_points)
  for (j in seq_len(n_patients)) {
    # Linear predictor and hazard ratio multiplier
    linear_pred <- beta_intercept + beta_age * all_patients$age[j]
    mu <- exp(linear_pred)
    # Apply Cox survival function: S(t|x) = exp(-H₀(t) × exp(β₀ + β₁×age))
    patient_survivals[j, ] <- exp(-cumulative_hazard_pred * mu)
  }
  # Population-level survival curve: average across all patients
  survival_matrix[i, ] <- colMeans(patient_survivals)
}
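The inner loop over patients can also be written with `outer()`. The sketch below uses made-up inputs (`ages`, `H0`, `beta0`, and `beta1` are illustrative values, not output of the fitted model) and checks, for a single draw, that the vectorized result matches the explicit loop.

```r
set.seed(42)
ages  <- c(45, 52, 60, 68, 75)        # illustrative patient ages
H0    <- cumsum(runif(10, 0, 0.2))    # illustrative cumulative baseline hazard at 10 time points
beta0 <- -3                           # illustrative intercept for one draw
beta1 <- 0.04                         # illustrative age coefficient for one draw

mu <- exp(beta0 + beta1 * ages)       # hazard multiplier per patient

# Explicit loop, one row per patient
surv_loop <- matrix(NA_real_, nrow = length(ages), ncol = length(H0))
for (j in seq_along(ages)) {
  surv_loop[j, ] <- exp(-H0 * mu[j])
}

# Vectorized: outer(mu, H0) has patients in rows and time points in columns
surv_vec <- exp(-outer(mu, H0))

# Population-averaged curve for this draw
pop_curve <- colMeans(surv_vec)

print(all.equal(surv_loop, surv_vec))
```

With thousands of draws, replacing the per-patient loop this way removes most of the R-level iteration while leaving the per-draw averaging unchanged.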
6️⃣ Summarize Posterior Distribution of Survival Curves
# Calculate posterior mean and credible intervals for survival probabilities
summary_curve_cox <- data.frame(
  tend = pamm_time_grid,
  mean_survival = apply(survival_matrix, 2, mean),
  lower_ci = apply(survival_matrix, 2, quantile, 0.025),
  upper_ci = apply(survival_matrix, 2, quantile, 0.975)
)
# Save results for comparison with other methods
readr::write_rds(summary_curve_cox, here::here("docs/write/sub_projects/12_PAMM/summary_curve_cox.Rds"))
cat("Successfully generated Bayesian Cox survival curves\n")
cat("Final survival at t =", max(pamm_time_grid), ":",
    round(tail(summary_curve_cox$mean_survival, 1), 3), "\n")
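Both interval bounds can be computed in one pass by giving `quantile()` a vector of probabilities. The sketch below runs on a simulated draws-by-time matrix; `toy_matrix`, `rates`, and `times_toy` are made up for illustration and stand in for `survival_matrix` and the time grid.

```r
set.seed(7)
n_draws_toy <- 500
times_toy   <- 0:20

# Each row is one simulated survival curve exp(-rate * t) with a draw-specific rate
rates      <- rgamma(n_draws_toy, shape = 20, rate = 400)
toy_matrix <- exp(-outer(rates, times_toy))

# 2 x n_time matrix: row 1 = 2.5% quantile, row 2 = 97.5% quantile
bounds <- apply(toy_matrix, 2, quantile, probs = c(0.025, 0.975))

toy_summary <- data.frame(
  tend          = times_toy,
  mean_survival = colMeans(toy_matrix),
  lower_ci      = bounds[1, ],
  upper_ci      = bounds[2, ]
)

print(head(toy_summary, 3))
```

A useful sanity check on the real output is that the curve starts at 1 at time 0 and that the lower bound never exceeds the upper bound.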
💡 Key Learnings & Notes
- Posterior propagation: Looping over every posterior draw carries all parameter uncertainty through to the curve, so the intervals are genuine Bayesian credible intervals.
- Spline mathematics: Reconstructing the curve by hand requires both bases: M-splines for the baseline hazard and I-splines for its cumulative form.
- Population-averaged inference: Curve averages survival probabilities across real patients at each time, reflecting sample covariate distribution.
- No hidden averaging: No shortcuts like plugging in a mean age; every step honors full joint uncertainty.
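Plugging in a mean age is not equivalent to averaging over patients because survival is a nonlinear function of age, so survival at the average age generally differs from the average survival. A small numeric sketch with made-up values (`ages`, `beta0`, `beta1`, and `H0_t` are illustrative, not estimates from the model):

```r
ages  <- c(40, 60, 80)   # illustrative patient ages
beta0 <- -3              # illustrative intercept
beta1 <- 0.05            # illustrative log-hazard ratio per year of age
H0_t  <- 1               # illustrative cumulative baseline hazard at one time point

# Cox survival at a single time point for a given age
surv_at <- function(age) exp(-H0_t * exp(beta0 + beta1 * age))

s_avg  <- mean(surv_at(ages))   # average of individual survival probabilities
s_plug <- surv_at(mean(ages))   # survival of a hypothetical "average" patient

print(c(population_average = s_avg, mean_age_plug_in = s_plug))
```

The gap between the two numbers grows as the covariate spread or the effect size increases, which is why the workflow averages survival probabilities rather than covariates.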