Inspecting posteriors

if(!requireNamespace("fabricatr", quietly = TRUE)) {
  install.packages("fabricatr")
}

library(CausalQueries)
library(fabricatr)
library(knitr)
library(ggplot2)
library(rstan)
library(bayesplot)
rstan_options(refresh = 0)

Accessing the posterior

When you update a model, CausalQueries generates and updates a Stan model and saves the posterior distribution over the model's parameters.

The basic usage is:

data <- data.frame(X = rep(c(0:1), 10), Y = rep(c(0:1), 10))

model <- make_model("X -> Y") |> 
  update_model(data)

The posterior over parameters can be accessed thus:

grab(model, "posterior_distribution")
#> Summary statistics of model parameter posterior distributions:
#> : 4000 rows (draws) by 6 cols (parameters)
#> 
#>      mean   sd
#> X.0  0.50 0.10
#> X.1  0.50 0.10
#> Y.00 0.08 0.07
#> Y.10 0.04 0.04
#> Y.01 0.80 0.11
#> Y.11 0.08 0.07
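The printout above summarizes 4000 draws of 6 parameters, so other summaries can be computed from the draws themselves. The following is a minimal sketch. It assumes that the object returned by grab(model, "posterior_distribution") can be coerced to an ordinary data frame with one column per parameter, named as in the printout. That structure is inferred from the printout rather than stated by it.

```r
library(CausalQueries)

# Same data and model as in the basic usage example
data <- data.frame(X = rep(0:1, 10), Y = rep(0:1, 10))
model <- make_model("X -> Y") |>
  update_model(data)

# Assumption: the posterior object behaves like a data frame of draws
draws <- as.data.frame(grab(model, "posterior_distribution"))

# Posterior mean and 95% credible interval for each parameter
post_summary <- sapply(draws, function(x) {
  c(mean = mean(x), quantile(x, probs = c(0.025, 0.975)))
})
round(post_summary, 2)
```

Working from the draws rather than from the printed summary makes it possible to compute any interval or function of parameters, not only the mean and standard deviation.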

When querying a model you can request use of the posterior distribution with the using argument:

model |> 
  query_model(
    query = "Y[X=1] > Y[X=0]",
    using = c("priors", "posteriors")) |>
  kable(digits = 2)
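|query           |given |using      |case_level | mean|   sd| cred.low| cred.high|
|:---------------|:-----|:----------|:----------|----:|----:|--------:|---------:|
|Y[X=1] > Y[X=0] |-     |priors     |FALSE      | 0.25| 0.20|     0.01|      0.70|
|Y[X=1] > Y[X=0] |-     |posteriors |FALSE      | 0.80| 0.11|     0.54|      0.95|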
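In this two-node model the query Y[X=1] > Y[X=0] is satisfied only by the nodal type Y.01 (Y = 0 when X = 0 and Y = 1 when X = 1), which is why the posterior answer to the query matches the Y.01 row of the parameter summary. The sketch below checks this correspondence. It assumes that the posterior draws can be treated as a data frame with a column named Y.01, and that the query result has a mean column as in the table.

```r
library(CausalQueries)

data <- data.frame(X = rep(0:1, 10), Y = rep(0:1, 10))
model <- make_model("X -> Y") |>
  update_model(data)

# Assumption: draws are stored one column per parameter
draws <- as.data.frame(grab(model, "posterior_distribution"))

# Share of units with a positive effect of X on Y, draw by draw
positive_effect <- draws$Y.01

# Mean, sd, and 95% credible interval computed by hand
c(mean = mean(positive_effect),
  sd = sd(positive_effect),
  quantile(positive_effect, probs = c(0.025, 0.975)))

# The same quantity as reported by query_model, for comparison
q <- query_model(model, query = "Y[X=1] > Y[X=0]", using = "posteriors")
q$mean
```

For richer models a query generally maps to a sum over several causal types, so the hand computation is useful mainly as a sanity check on simple cases.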

Summary of Stan performance

You can access a summary of the parameter values and convergence information as produced by Stan thus:

grab(model, "stan_summary")
#> Inference for Stan model: simplexes.
#> 4 chains, each with iter=2000; warmup=1000; thin=1; 
#> post-warmup draws per chain=1000, total post-warmup draws=4000.
#> 
#>              mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
#> X.0          0.50    0.00 0.10   0.29   0.43   0.50   0.57   0.70  2304    1
#> X.1          0.50    0.00 0.10   0.30   0.43   0.50   0.57   0.71  2304    1
#> Y.00         0.08    0.00 0.07   0.00   0.03   0.06   0.12   0.28  2024    1
#> Y.10         0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.15  4064    1
#> Y.01         0.80    0.00 0.11   0.54   0.73   0.82   0.88   0.95  3669    1
#> Y.11         0.08    0.00 0.07   0.00   0.02   0.06   0.11   0.28  4184    1
#> X0.Y00       0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  2049    1
#> X1.Y00       0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.15  2030    1
#> X0.Y10       0.02    0.00 0.02   0.00   0.01   0.01   0.03   0.08  3946    1
#> X1.Y10       0.02    0.00 0.02   0.00   0.01   0.01   0.03   0.08  3872    1
#> X0.Y01       0.40    0.00 0.10   0.21   0.33   0.40   0.47   0.60  2439    1
#> X1.Y01       0.40    0.00 0.10   0.21   0.33   0.40   0.47   0.60  2842    1
#> X0.Y11       0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  4130    1
#> X1.Y11       0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  3772    1
#> lp__       -14.59    0.04 1.55 -18.54 -15.34 -14.24 -13.47 -12.65  1305    1
#> 
#> Samples were drawn using NUTS(diag_e) at Thu Apr 25 21:53:12 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at 
#> convergence, Rhat=1).

This summary reports the posterior distribution of each parameter along with convergence diagnostics, summarized in the Rhat column. In the printout above, the first 6 rows show the distribution of the model parameters and the next 8 rows show the distribution over transformed parameters, here the causal types. The last row shows the unnormalized log density on Stan's unconstrained space which, as described in the Stan documentation, is intended to diagnose sampling efficiency and evaluate approximations.

See the Stan documentation for further details.

Advanced diagnostics

If you are interested in more advanced performance diagnostics, you can save and access the raw Stan output.

model <- make_model("X -> Y") |> 
  update_model(data, keep_fit = TRUE)

Note that the summary for this raw output shows the labels used in the generic Stan model: lambdas for the vector of parameters, corresponding to the parameters in the parameters dataframe (grab(model, "parameters_df")), and, if saved, a vector types for the causal types (see grab(model, "causal_types")) and w for the event probabilities (grab(model, "event_probabilities")).

model |> grab("stan_fit")
#> Inference for Stan model: simplexes.
#> 4 chains, each with iter=2000; warmup=1000; thin=1; 
#> post-warmup draws per chain=1000, total post-warmup draws=4000.
#> 
#>              mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
#> lambdas[1]   0.50    0.00 0.10   0.30   0.42   0.50   0.57   0.69  2248    1
#> lambdas[2]   0.50    0.00 0.10   0.31   0.43   0.50   0.58   0.70  2248    1
#> lambdas[3]   0.08    0.00 0.07   0.00   0.03   0.06   0.12   0.27  2181    1
#> lambdas[4]   0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  4135    1
#> lambdas[5]   0.80    0.00 0.11   0.55   0.74   0.82   0.88   0.95  3670    1
#> lambdas[6]   0.08    0.00 0.07   0.00   0.03   0.06   0.11   0.27  4110    1
#> types[1]     0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  2173    1
#> types[2]     0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  2214    1
#> types[3]     0.02    0.00 0.02   0.00   0.01   0.01   0.03   0.08  3987    1
#> types[4]     0.02    0.00 0.02   0.00   0.01   0.01   0.03   0.07  3754    1
#> types[5]     0.40    0.00 0.10   0.21   0.33   0.39   0.46   0.60  2524    1
#> types[6]     0.40    0.00 0.10   0.22   0.33   0.40   0.47   0.60  2567    1
#> types[7]     0.04    0.00 0.04   0.00   0.01   0.03   0.05   0.14  3974    1
#> types[8]     0.04    0.00 0.04   0.00   0.01   0.03   0.06   0.14  3811    1
#> lp__       -14.58    0.04 1.54 -18.27 -15.39 -14.23 -13.43 -12.64  1437    1
#> 
#> Samples were drawn using NUTS(diag_e) at Thu Apr 25 21:53:13 2024.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at 
#> convergence, Rhat=1).

You can then use diagnostic packages such as bayesplot.

model |> grab("stan_fit") |>
  bayesplot::mcmc_pairs(pars = c("lambdas[3]", "lambdas[4]", "lambdas[5]", "lambdas[6]"))

np <- model |> grab("stan_fit") |> bayesplot::nuts_params()
head(np) |> kable()
| Chain| Iteration|Parameter     |     Value|
|-----:|---------:|:-------------|---------:|
|     1|         1|accept_stat__ | 0.9778075|
|     1|         2|accept_stat__ | 0.9993340|
|     1|         3|accept_stat__ | 0.9736244|
|     1|         4|accept_stat__ | 0.9478284|
|     1|         5|accept_stat__ | 0.9969060|
|     1|         6|accept_stat__ | 0.9388957|

model |> grab("stan_fit") |>
  bayesplot::mcmc_trace(pars = "lambdas[5]", np = np) 
#> No divergences to plot.
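The same fit object can also be checked numerically rather than visually. The sketch below uses the rstan summary method and get_num_divergent() on the saved fit. The Rhat threshold of 1.01 is a common rule of thumb adopted here as an assumption, not a value prescribed by CausalQueries.

```r
library(CausalQueries)
library(rstan)

data <- data.frame(X = rep(0:1, 10), Y = rep(0:1, 10))
model <- make_model("X -> Y") |>
  update_model(data, keep_fit = TRUE)

fit <- grab(model, "stan_fit")

# Matrix of summary statistics, one row per quantity in the Stan model
fit_summary <- summary(fit)$summary

# Names of quantities whose Rhat exceeds the chosen threshold
# (1.01 is an assumed rule of thumb)
rhat <- fit_summary[, "Rhat"]
names(rhat)[!is.na(rhat) & rhat > 1.01]

# Smallest effective sample size across all quantities
min(fit_summary[, "n_eff"], na.rm = TRUE)

# Number of divergent transitions after warmup
rstan::get_num_divergent(fit)
```

An empty result from the Rhat check and a divergence count of zero are consistent with the "No divergences to plot" message above.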