🟦 Bayesian Cox Survival Curves with brms: Manual M/I-Spline Reconstruction and Full Posterior Inference

Brief:
End-to-end guide for fitting flexible Bayesian Cox proportional hazards models with brms and manually reconstructing population-level survival curves with credible intervals, using all posterior draws and explicit spline mathematics.


📝 Summary

This workflow details, step by step, how to extract population survival curves (with Bayesian credible intervals) from a Bayesian Cox model, explicitly handling the M-spline/I-spline structure used by brms for the flexible baseline hazard.


🔑 Key Points

1️⃣ Bayesian Cox Model with M-spline Baseline

  • M-splines (5 basis functions by default): Flexible spline basis for the baseline hazard in brms, with no fixed parametric shape imposed
  • Population-level curve: Averaging over the sample’s covariate distribution (not just a hypothetical "average" patient)
  • All posterior draws: Full propagation of model uncertainty

2️⃣ Survival & Hazard Formulae

  • Individual hazard:
    $$ h(t|x) = h_0(t) \times \exp(\beta_0 + \beta_1\, \text{age}) $$
  • Survival:
    $$S(t|x) = \exp(-H_0(t) \times \exp(\beta_0 + \beta_1\, \text{age})) $$
    where \(H_0(t)\) is built from I-splines (the integrated M-spline basis).
  • Population survival:
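    Mean of \(S(t|x_i)\) across all \(n\) observed patients, computed separately for each posterior draw and time point:
    $$ \bar{S}(t) = \frac{1}{n} \sum_{i=1}^{n} S(t|x_i) $$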

✅ "How-To" Steps

1️⃣ Fit the Bayesian Cox Model

# ============================================================================
# Bayesian Cox Proportional Hazards Survival Analysis with Manual Curve Extraction
# ============================================================================

library(survival)
library(brms)
library(splines2)
library(dplyr)

# Load VA lung cancer dataset (137 patients)
veteran <- survival::veteran  

# Fit Cox proportional hazards model using brms
# - Uses M-splines (5 basis functions by default) for flexible baseline hazard
# - Avoids parametric assumptions about hazard shape over time
# - Age as continuous covariate with proportional hazards assumption
cox_vet <- brms::brm(
  formula = time | cens(1 - status) ~ age,  # Survival outcome with censoring
  family = brms::cox(),                     # Cox proportional hazards
  data = veteran,
  cores = 4,                                # One core per chain; extra cores go unused with 4 chains
  chains = 4,
  iter = 3000,
  warmup = 1000,
  seed = 123
)

2️⃣ Extract Posterior Samples

# Get posterior draws for manual survival curve calculation
cox_draws <- brms::as_draws_df(cox_vet)

3️⃣ Set Up Population and Time Grid for Predictions

# Use entire patient population for proper age distribution weighting
# This creates population-averaged survival curves, not individual predictions
all_patients <- veteran |> 
  dplyr::select(age)
n_patients <- nrow(all_patients)  # 137 patients

# Create regular time grid for survival curve evaluation
max_time <- max(veteran$time)
pamm_time_grid <- seq(0, max_time, by = 1)
n_time_points <- length(pamm_time_grid)

cat("Predicting for", n_patients, "patients over", n_time_points, "time points\n")

4️⃣ Extract Spline Parameters and Posterior Coefficients

# Extract M-spline parameters that brms used for baseline hazard
# To manually recreate survival curves, we need the exact same spline basis
# that brms used internally when fitting the model

# Internal knots: points where spline pieces connect (usually at data quantiles)
knots <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "knots")

# Boundary knots: start/end points of the spline domain (the observed time range, which brms may pad slightly)
boundary_knots <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "Boundary.knots")

# Degree: polynomial degree of each spline piece (typically 3 for smooth curves)
degree <- attr(cox_vet$basis$dpars$mu$bhaz$basis_matrix, "degree")

# Why extract these? The baseline hazard h₀(t) was modeled using M-splines with
# these exact parameters. To calculate survival curves manually, we need to 
# recreate the same I-spline basis (integral of M-splines) using identical settings.

# Extract posterior samples of model parameters
age_coef <- cox_draws$b_age                    # Age effect coefficient
intercept_coef <- cox_draws$b_Intercept        # Baseline hazard intercept
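sbhaz_cols <- grep("^sbhaz\\[", names(cox_draws), value = TRUE)   # M-spline basis coefficients (5 by default)
sbhaz_coefs <- as.matrix(as.data.frame(cox_draws)[, sbhaz_cols])  # Plain matrix: one row per draw, no draws_df subsetting warning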

n_draws <- nrow(cox_draws)
survival_matrix <- matrix(NA, nrow = n_draws, ncol = n_time_points)
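Selecting the `sbhaz` columns by name pattern rather than by fixed indices keeps the extraction valid if the basis size changes. The following is a minimal sketch on a mock draws table: `mock_draws` and all of its values are invented for illustration and are not output from the fitted model. The row-sum check assumes that brms declares `sbhaz` as a simplex, which is worth confirming against the generated Stan code for your brms version.

```r
# Mock posterior draws (hypothetical values, NOT from the fitted model)
set.seed(1)
n_mock <- 200
k_basis <- 5

# Normalised gamma variates give rows that sum to 1 (a Dirichlet-like draw)
raw <- matrix(rgamma(n_mock * k_basis, shape = 1), nrow = n_mock)
simplex_draws <- raw / rowSums(raw)

mock_draws <- data.frame(
  b_Intercept = rnorm(n_mock, mean = -5, sd = 0.2),
  b_age = rnorm(n_mock, mean = 0.01, sd = 0.005),
  simplex_draws,
  check.names = FALSE
)
# Rename the spline columns to the bracketed names used in the draws table
names(mock_draws)[3:(2 + k_basis)] <- paste0("sbhaz[", seq_len(k_basis), "]")

# Name-based selection: picks up however many basis coefficients exist
sbhaz_cols_mock <- grep("^sbhaz\\[", names(mock_draws), value = TRUE)
sbhaz_mat <- as.matrix(mock_draws[, sbhaz_cols_mock])

# Sanity checks: expected column count, and each draw sums to 1
# (the sum-to-1 check relies on the simplex assumption stated above)
stopifnot(ncol(sbhaz_mat) == k_basis)
stopifnot(all(abs(rowSums(sbhaz_mat) - 1) < 1e-8))
```

On a real fit, the same two checks can be run on the extracted coefficient matrix before entering the main loop, with the expected column count taken from the stored basis matrix.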

🧠 Mathematical/Model Commentary

KEY INSIGHT:
brms Cox models work as follows:
1. Baseline hazard \(h_0(t)\) is modeled using M-splines (hazard rate).
2. Cumulative baseline hazard \(H_0(t)\) uses I-splines (integral of M-splines).
3. Individual hazard: \(h(t|x) = h_0(t) \exp(\beta_0 + \beta_1\, \text{age})\)
4. Survival function: \(S(t|x) = \exp(-H_0(t) \exp(\beta_0 + \beta_1\, \text{age}))\)

Breakdown & Intuition:
  • \(h_0(t)\) = baseline hazard (from M-splines)
  • \(\beta_0\) = intercept, \(\beta_1\) = age log-hazard ratio per year
  • \(\exp(\beta_0 + \beta_1\cdot \text{age})\) = patient-specific hazard multiplier
  • \(H_0(t)\) = cumulative hazard up to \(t\) (I-spline)
  • More cumulative hazard (or higher risk/multiplier) → lower survival probability
  • Survival is always \(S(t) = \exp(-H(t))\), ensuring values in [0,1]

Some numeric examples (rounded):
  • \(H = 0 \rightarrow S = 1.0\) (100%)
  • \(H = 0.7 \rightarrow S \approx 0.5\) (50%)
  • \(H = 2.3 \rightarrow S \approx 0.1\) (10%)
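The mapping from cumulative hazard to survival can be checked directly. This short sketch reproduces the values above and then applies a patient-specific multiplier; `beta0` and `beta1` are hypothetical numbers chosen for illustration, not estimates from the fitted model.

```r
# Cumulative hazard -> survival probability
H <- c(0, 0.7, 2.3)
S <- exp(-H)
round(S, 3)

# Inverse direction: the cumulative hazard at which survival reaches 50%
H_median <- -log(0.5)

# Effect of the hazard multiplier exp(beta0 + beta1 * age) at a fixed H0(t)
beta0 <- -0.5   # hypothetical intercept
beta1 <- 0.01   # hypothetical log-hazard ratio per year of age
ages <- c(40, 60, 80)
multiplier <- exp(beta0 + beta1 * ages)
S_by_age <- exp(-0.7 * multiplier)  # survival at the time point where H0(t) = 0.7
```

With a positive `beta1`, the multiplier grows with age, so `S_by_age` decreases from the youngest to the oldest patient.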

About M-splines:
  • Non-negative, locally supported spline basis for \(h_0(t)\)
  • Default: 5 basis functions with quantile-based knots in brms
  • A weighted sum can produce U-shaped, increasing, decreasing, or multi-modal hazards
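The M-spline/I-spline relationship can be verified numerically with splines2. The sketch below uses an invented knot layout and equal weights (both are assumptions for illustration, unrelated to the fitted model) and compares the I-spline cumulative hazard with a trapezoid-rule integral of the M-spline hazard.

```r
library(splines2)

# Hypothetical spline setup for illustration only
t_grid <- seq(0, 10, by = 0.01)
inner_knots <- c(2, 5)
bound_knots <- c(0, 10)

m_basis <- mSpline(t_grid, knots = inner_knots, degree = 3,
                   intercept = TRUE, Boundary.knots = bound_knots)
i_basis <- iSpline(t_grid, knots = inner_knots, degree = 3,
                   intercept = TRUE, Boundary.knots = bound_knots)

# Equal weights that sum to 1 (a simplex, like one draw of spline coefficients)
w <- rep(1 / ncol(m_basis), ncol(m_basis))

h0 <- as.numeric(m_basis %*% w)  # baseline hazard from M-splines
H0 <- as.numeric(i_basis %*% w)  # cumulative baseline hazard from I-splines

# Trapezoid-rule integral of h0 as an independent check on H0
H0_numeric <- c(0, cumsum(diff(t_grid) * (head(h0, -1) + tail(h0, -1)) / 2))
max_abs_diff <- max(abs(H0 - H0_numeric))
```

Each M-spline integrates to 1 over the boundary interval, so with weights on a simplex the unscaled cumulative hazard rises from 0 to 1 at the upper boundary. The overall level of the hazard then comes from the multiplier \(\exp(\beta_0 + \beta_1\, \text{age})\), which is why the intercept must be included in the manual reconstruction.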


5️⃣ Manual Survival Curve Calculation Using I-splines

# Create I-spline basis for cumulative hazard at our prediction times
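ibasis_pred <- iSpline(
  pamm_time_grid,
  knots = knots,
  Boundary.knots = boundary_knots,
  degree = degree,
  intercept = TRUE  # Stated explicitly; matches the brms default basis (adjust if bhaz settings were customized)
)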

# Loop through each posterior draw to create survival curve distribution
for (i in 1:n_draws) {
  if(i %% 1000 == 0) cat("Processing draw", i, "of", n_draws, "\n")

  # Extract parameters for this posterior draw
  beta_age <- age_coef[i]         # β₁ - age coefficient
  beta_intercept <- intercept_coef[i]  # β₀ - intercept coefficient
  sbhaz_draw <- as.numeric(sbhaz_coefs[i, ])  # M-spline coefficients

  # Calculate cumulative baseline hazard H₀(t) at prediction times
  cumulative_hazard_pred <- as.numeric(ibasis_pred %*% sbhaz_draw)
  # (%*% means matrix multiplication)
  # Calculate survival for each patient at each time point
  patient_survivals <- matrix(NA, nrow = n_patients, ncol = n_time_points)

  for (j in 1:n_patients) {
    # Linear predictor and hazard ratio multiplier
    linear_pred <- beta_intercept + beta_age * all_patients$age[j]
    mu <- exp(linear_pred)
    # Apply Cox survival function: S(t|x) = exp(-H₀(t) × exp(β₀ + β₁×age))
    patient_survivals[j, ] <- exp(-cumulative_hazard_pred * mu)
  }

  # Population-level survival curve: average across all patients
  survival_matrix[i, ] <- colMeans(patient_survivals)
}
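The inner loop over patients can be replaced by a single `outer()` call, which matters when the number of draws is large. The sketch below uses small invented inputs (`ages`, `H0`, `beta0`, `beta1` are hypothetical) to show that the vectorized form gives the same patient-by-time matrix as the loop.

```r
# Hypothetical inputs for a single posterior draw
ages <- c(45, 60, 72)
H0 <- c(0, 0.2, 0.5, 0.9)   # cumulative baseline hazard on a 4-point time grid
beta0 <- -0.3
beta1 <- 0.02

# Loop version, mirroring the structure used above
S_loop <- matrix(NA_real_, nrow = length(ages), ncol = length(H0))
for (j in seq_along(ages)) {
  S_loop[j, ] <- exp(-H0 * exp(beta0 + beta1 * ages[j]))
}

# Vectorized version: outer() builds the patients x time matrix of
# multiplier[j] * H0[t] in one step
multiplier <- exp(beta0 + beta1 * ages)
S_vec <- exp(-outer(multiplier, H0))

stopifnot(isTRUE(all.equal(S_loop, S_vec)))

# Population-level curve for this draw
pop_curve <- colMeans(S_vec)
```

In the full workflow, the same two lines (`outer()` then `colMeans()`) would sit inside the loop over draws, with the multiplier and cumulative hazard recomputed for each draw.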

6️⃣ Summarize Posterior Distribution of Survival Curves

# Calculate posterior mean and credible intervals for survival probabilities
summary_curve_cox <- data.frame(
  tend = pamm_time_grid,
  mean_survival = apply(survival_matrix, 2, mean),
  lower_ci = apply(survival_matrix, 2, quantile, 0.025),
  upper_ci = apply(survival_matrix, 2, quantile, 0.975)
)

# Save results for comparison with other methods
readr::write_rds(summary_curve_cox, here::here("docs/write/sub_projects/12_PAMM/summary_curve_cox.Rds"))

cat("Successfully generated Bayesian Cox survival curves\n")
cat("Final survival at t =", max(pamm_time_grid), ":", 
    round(tail(summary_curve_cox$mean_survival, 1), 3), "\n")
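Before saving or plotting, it can help to confirm that the summary behaves like a survival curve. The helper below, `check_survival_summary()`, is a hypothetical function written for this sketch, and the survival matrix is simulated from made-up exponential rates rather than taken from the model.

```r
# Simulated draws x time matrix (hypothetical rates, NOT model output)
set.seed(7)
t_grid <- 0:50
rates <- rlnorm(500, meanlog = log(0.03), sdlog = 0.2)
surv_mat <- exp(-outer(rates, t_grid))   # 500 draws x 51 time points

summary_mock <- data.frame(
  tend = t_grid,
  mean_survival = colMeans(surv_mat),
  lower_ci = apply(surv_mat, 2, quantile, 0.025),
  upper_ci = apply(surv_mat, 2, quantile, 0.975)
)

# Hypothetical helper: basic structural checks on a summarized survival curve
check_survival_summary <- function(s, tol = 1e-8) {
  c(
    starts_at_one  = abs(s$mean_survival[1] - 1) < tol,
    non_increasing = all(diff(s$mean_survival) <= tol),
    ordered_bounds = all(s$lower_ci <= s$mean_survival + tol &
                           s$mean_survival <= s$upper_ci + tol)
  )
}

checks <- check_survival_summary(summary_mock)
```

Applied to the real summary, a failed `starts_at_one` check would suggest that the time grid starts above the lower boundary knot, and a failed `non_increasing` check would point to a mismatch between the reconstructed basis and the one used in fitting.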

💡 Key Learnings & Notes

  • Posterior propagation: All parameter uncertainty is included by looping over every posterior draw, so the intervals are genuine Bayesian credible intervals.
  • Spline mathematics: Reconstructing curves by hand requires both M-splines (baseline hazard) and I-splines (cumulative baseline hazard), built with the same knots, boundary knots, and degree used in fitting.
  • Population-averaged inference: The curve averages survival probabilities across real patients at each time, reflecting the sample covariate distribution.
  • No hidden averaging: Survival is computed per patient and per draw rather than at a plugged-in mean age, so the summary reflects the joint posterior.
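Because survival is a nonlinear function of age, averaging survival over patients is not the same as evaluating survival at the average age. A minimal sketch with invented numbers (`ages`, `H0_t`, `beta0`, `beta1` are all hypothetical) makes the gap visible at a single time point.

```r
# Hypothetical inputs at one time point for one posterior draw
ages <- c(35, 50, 62, 70, 81)
H0_t <- 0.8     # cumulative baseline hazard at the chosen time
beta0 <- -1
beta1 <- 0.03

# Population-averaged survival: average the per-patient survival probabilities
S_each <- exp(-H0_t * exp(beta0 + beta1 * ages))
S_population <- mean(S_each)

# Plug-in shortcut: survival of a single patient at the mean age
S_at_mean_age <- exp(-H0_t * exp(beta0 + beta1 * mean(ages)))

gap <- S_population - S_at_mean_age
```

The sign and size of `gap` depend on the coefficients, the age spread, and the time point, so the plug-in curve cannot be corrected by a fixed offset; averaging over the observed patients avoids the issue.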