Introduction
In this article, we try out several methods of computing Bayes factors, using a linear regression model as the running example, and then compare the results to see the characteristics of each method.
The methods are Gaussian quadrature, the Savage-Dickey method, and bridge sampling, covered in that order below. Only Gaussian quadrature is an analytic method (albeit one that includes numerically estimating an integral); the other two estimate the Bayes factor from MCMC output. The first two methods are explained in detail in a separate article.
Linear regression model
We set up the linear regression model as follows; from here on, the symbols in the formulas follow the linked article.
$$ p(\boldsymbol{Y}) = \mathrm{Normal}\left(\boldsymbol{\mu}, \sigma^2 \boldsymbol{I}_N\right) = \cfrac{1}{(2\pi)^{N/2}\sigma^N}\exp\left( -\cfrac{1}{2\sigma^{2}}\left( \boldsymbol{Y} - \boldsymbol{\mu} \right)^T\left(\boldsymbol{Y} - \boldsymbol{\mu}\right) \right) \tag{1} $$
For the priors we adopt Zellner-Siow priors, following the objective Bayes school of thought.
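For reference, written out in the parametrization that the Stan code later in this article implements, the prior is

$$ \boldsymbol{\beta}_{\boldsymbol{\gamma}} \mid g, \sigma^2 \sim \mathrm{Normal}\left(\boldsymbol{0},\ N g \sigma^2 \left(\boldsymbol{X}_{\boldsymbol{\gamma}}^T \boldsymbol{X}_{\boldsymbol{\gamma}}\right)^{-1}\right), \qquad g \sim \mathrm{Inv\text{-}Gamma}\left(\cfrac{1}{2}, \cfrac{r}{2}\right) $$

$$ p(\alpha, \sigma^2) \propto \cfrac{1}{\sigma^{2}} \tag{2} $$

Marginalizing over $g$ turns the normal on $\boldsymbol{\beta}_{\boldsymbol{\gamma}}$ into a multivariate Cauchy with scale matrix $r N \sigma^2 (\boldsymbol{X}_{\boldsymbol{\gamma}}^T \boldsymbol{X}_{\boldsymbol{\gamma}})^{-1}$, which is the Zellner-Siow form, and $(2)$ is Jeffreys' prior on the intercept and the variance, which MCMC software can only represent approximately (this will come back to haunt us in the discussion at the end).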
Gaussian quadrature
This method obtains the Bayes factor analytically, using BayesFactor::regressionBF().
This function computes the Bayes factor by using Gaussian quadrature to estimate, with high accuracy, an expression containing the integral below.
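The specific expression comes from the package's references (Liang et al., 2008; Rouder & Morey, 2012): $\alpha$, $\boldsymbol{\beta}_{\boldsymbol{\gamma}}$, and $\sigma^2$ integrate out in closed form, leaving a single one-dimensional integral over $g$. In Liang et al.'s parametrization it reads

$$ \mathrm{BF}_{\gamma 0} = \int_0^\infty \cfrac{(1+g)^{(N-1-p_{\gamma})/2}}{\left[ 1+g\left(1-R_{\gamma}^{2}\right) \right]^{(N-1)/2}}\ \pi(g)\, dg $$

where $R_{\gamma}^{2}$ is the coefficient of determination of model $M_{\gamma}$ and $\pi(g)$ is the inverse-gamma mixing density implied by the Zellner-Siow prior (BayesFactor's parametrization of $g$ differs slightly, folding in $N$ and the rscaleCont scale $r$). It is this integral that regressionBF() evaluates by Gaussian quadrature.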
The script is as follows.
# data preparation --------------------------------------------------------
library(tidyverse)
d <- attitude
response_variable <- d$rating
predictors <- d %>% select(-rating) %>% mutate_all(~ . - mean(.)) %>% as_tibble() # center each predictor
# regressionBF ------------------------------------------------------------
library(BayesFactor)
bf <- regressionBF(rating ~ ., data=d, rscaleCont = "medium")
bf
# Bayes factor analysis
# --------------
# [1] complaints : 417938.6 ±0.01%
# [2] privileges : 3.177784 ±0%
# [3] learning : 103.393 ±0%
# [4] raises : 47.00917 ±0.01%
# [5] critical : 0.4493186 ±0%
# [6] advance : 0.4472295 ±0%
# [7] complaints + privileges : 75015.23 ±0%
# [8] complaints + learning : 207271.9 ±0%
# [9] complaints + raises : 77498.99 ±0%
# [10] complaints + critical : 70087.3 ±0%
# [11] complaints + advance : 72759.76 ±0%
# [12] privileges + learning : 42.16342 ±0.01%
# [13] privileges + raises : 25.89924 ±0%
# [14] privileges + critical : 1.382939 ±0%
# [15] privileges + advance : 1.234018 ±0%
# [16] learning + raises : 100.5111 ±0.01%
# [17] learning + critical : 33.96788 ±0%
# [18] learning + advance : 68.89183 ±0.01%
# [19] raises + critical : 15.70954 ±0%
# [20] raises + advance : 35.60357 ±0%
# [21] critical + advance : 0.2393626 ±0.01%
# [22] complaints + privileges + learning : 56341.12 ±0%
# [23] complaints + privileges + raises : 18230.79 ±0%
# [24] complaints + privileges + critical : 16235.87 ±0%
# [25] complaints + privileges + advance : 16482.37 ±0%
# [26] complaints + learning + raises : 42847.98 ±0%
# [27] complaints + learning + critical : 42374.73 ±0%
# [28] complaints + learning + advance : 88041.54 ±0%
# [29] complaints + raises + critical : 16913.97 ±0%
# [30] complaints + raises + advance : 20641.12 ±0%
# [31] complaints + critical + advance : 15821.13 ±0%
# [32] privileges + learning + raises : 39.02367 ±0%
# [33] privileges + learning + critical : 16.31662 ±0%
# [34] privileges + learning + advance : 38.47538 ±0%
# [35] privileges + raises + critical : 10.2571 ±0%
# [36] privileges + raises + advance : 28.71547 ±0%
# [37] privileges + critical + advance : 0.6451984 ±0%
# [38] learning + raises + critical : 33.4565 ±0%
# [39] learning + raises + advance : 313.3885 ±0%
# [40] learning + critical + advance : 35.08545 ±0%
# [41] raises + critical + advance : 13.36863 ±0%
# [42] complaints + privileges + learning + raises : 13707.57 ±0%
# [43] complaints + privileges + learning + critical : 13604.67 ±0%
# [44] complaints + privileges + learning + advance : 24141.94 ±0%
# [45] complaints + privileges + raises + critical : 4737.976 ±0%
# [46] complaints + privileges + raises + advance : 5450.104 ±0%
# [47] complaints + privileges + critical + advance : 4280.874 ±0%
# [48] complaints + learning + raises + critical : 10520.62 ±0%
# [49] complaints + learning + raises + advance : 23334 ±0%
# [50] complaints + learning + critical + advance : 22169.3 ±0%
# [51] complaints + raises + critical + advance : 5302.837 ±0%
# [52] privileges + learning + raises + critical : 15.13044 ±0%
# [53] privileges + learning + raises + advance : 134.4662 ±0%
# [54] privileges + learning + critical + advance : 20.64672 ±0%
# [55] privileges + raises + critical + advance : 11.7191 ±0.01%
# [56] learning + raises + critical + advance : 106.5402 ±0%
# [57] complaints + privileges + learning + raises + critical : 3826.973 ±0%
# [58] complaints + privileges + learning + raises + advance : 7144.828 ±0%
# [59] complaints + privileges + learning + critical + advance : 6926.973 ±0%
# [60] complaints + privileges + raises + critical + advance : 1608.397 ±0%
# [61] complaints + learning + raises + critical + advance : 6461.751 ±0%
# [62] privileges + learning + raises + critical + advance : 51.02312 ±0%
# [63] complaints + privileges + learning + raises + critical + advance : 2211.912 ±0%
#
# Against denominator:
# Intercept only
# ---
# Bayes factor type: BFlinearModel, JZS
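The comparison code later in this article reads individual Bayes factors out of this bf object via its bayesFactor attribute, which stores the BFs on the log scale. For reference, two equivalent ways to get at the numbers:

# log(BF10) for a single model sits in the bayesFactor attribute
exp(attributes(bf["complaints"])$bayesFactor[["bf"]])
# or extract all 63 BF10 values at once as a data frame
extractBF(bf)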
Savage-Dickey method
This is one of the MCMC-based estimation methods. For details on the Savage-Dickey method itself, see the linked article. Note that it only applies to nested models.
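In brief: the null model $M_N$ is the special case of $M_{\gamma}$ with $\boldsymbol{\beta}_{\boldsymbol{\gamma}} = \boldsymbol{0}$, so the Savage-Dickey density ratio reduces the Bayes factor to two density evaluations at that single point,

$$ \mathrm{BF}_{01} = \cfrac{p\left(\boldsymbol{\beta}_{\boldsymbol{\gamma}} = \boldsymbol{0} \mid \boldsymbol{Y}, M_{\gamma}\right)}{p\left(\boldsymbol{\beta}_{\boldsymbol{\gamma}} = \boldsymbol{0} \mid M_{\gamma}\right)} $$

that is, the posterior density at the null value divided by the prior density at the null value. The generated quantities block of the Stan code below computes exactly these two log densities (logPostDensBeta and logPriorDensBeta).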
The Stan code is as follows.
// for Savage-Dickey method with cmdstanr
data {
int<lower=0> N, p_gamma;
// N : sample size, p_gamma : number of predictors
vector[N] Y; // response variable
matrix[N, p_gamma] X_gamma; // predictors
real Jeffreys_alpha, Jeffreys_beta, r;
//r : scale of the Cauchy
//Jeffreys_alpha : shape of the gamma prior on sigma^2 (sufficiently small)
//Jeffreys_beta : rate of the gamma prior on sigma^2 (sufficiently small)
//(small values make this gamma prior approximate Jeffreys' prior)
}
parameters {
vector[p_gamma] beta_gamma;
real<lower=0> sigma, g;
real alpha;
// sigma : standard deviation of Y
// g : variance of the standardized slope
// alpha : intercept
}
transformed parameters{
vector[N] mu;
mu = X_gamma * beta_gamma;
}
model {
//model
target += normal_lpdf(Y - rep_vector(alpha, N) | mu, sigma);
// prior
target += multi_normal_lpdf( beta_gamma | rep_vector(0, p_gamma), N * g * sigma^2 * inverse(crossprod(X_gamma)));
target += inv_gamma_lpdf(g | 0.5, r*0.5 );
target += gamma_lpdf(sigma^2 | Jeffreys_alpha, Jeffreys_beta);
target += normal_lpdf(alpha | 0, 100);
}
generated quantities{
vector[p_gamma] mean_beta;
matrix[p_gamma,p_gamma] covariance_beta;
//mean_beta : conditional posterior mean of beta_gamma (given alpha, sigma, g), used for the CMDE
//covariance_beta : conditional posterior covariance of beta_gamma
real logPostDensBeta, logPriorDensBeta;
//logPriorDensBeta : log prior density of beta_gamma evaluated at the zero vector
//logPostDensBeta : log conditional posterior density of beta_gamma evaluated at the zero vector
real BF_01, BF_10;
//BF_01 : BF[M_N:M_gamma], BF_10 : BF[M_gamma:M_N]
matrix[p_gamma,p_gamma] ScaleMatrix_PriorDensBeta = N * r * sigma^2 * inverse_spd(crossprod(X_gamma));
//ScaleMatrix_PriorDensBeta : scale matrix of the (Cauchy) prior on beta_gamma after marginalizing g
covariance_beta = (sigma^2 * ((N * g)^(-1) + 1)^(-1) ) * inverse_spd(crossprod(X_gamma));
mean_beta = sigma^(-2) * covariance_beta * X_gamma' * (Y - rep_vector(alpha, N));
logPostDensBeta = - 0.5 * p_gamma * (log(2) + log(pi())) - 0.5 * log_determinant(covariance_beta) - 0.5 * mean_beta' * inverse_spd(covariance_beta) * mean_beta;
logPriorDensBeta = lgamma((1+p_gamma)*0.5) - ((p_gamma + 1) * 0.5) * log(pi()) - 0.5 * log_determinant(ScaleMatrix_PriorDensBeta);
BF_01 = exp(logPostDensBeta - logPriorDensBeta);
BF_10 = exp(logPriorDensBeta - logPostDensBeta);
}
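A note on the generated quantities block above: rather than estimating the posterior density of $\boldsymbol{\beta}_{\boldsymbol{\gamma}}$ at $\boldsymbol{0}$ by smoothing the MCMC draws, it exploits conjugacy. Conditional on $(\alpha, \sigma^2, g)$, standard Gaussian algebra gives

$$ \boldsymbol{\beta}_{\boldsymbol{\gamma}} \mid \boldsymbol{Y}, \alpha, \sigma^2, g \sim \mathrm{Normal}\left(\boldsymbol{m}, \boldsymbol{\Sigma}\right), \qquad \boldsymbol{\Sigma} = \sigma^{2}\left(1 + \cfrac{1}{Ng}\right)^{-1}\left(\boldsymbol{X}_{\boldsymbol{\gamma}}^T\boldsymbol{X}_{\boldsymbol{\gamma}}\right)^{-1}, \qquad \boldsymbol{m} = \sigma^{-2}\, \boldsymbol{\Sigma}\, \boldsymbol{X}_{\boldsymbol{\gamma}}^T\left(\boldsymbol{Y} - \alpha\boldsymbol{1}_N\right) $$

which is exactly what covariance_beta and mean_beta compute. Averaging the per-draw quantities over the MCMC sample (the posterior means of BF_01 and BF_10 taken in the R code below) then gives the conditional marginal density estimate, the CMDE of the comments.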
If you just want to run this, the following is all it takes.
# data preparation ---------------------------------------------------------------
response_variable <- d$rating
predictors <- d %>% select(-rating) %>% mutate_all(~ . - mean(.)) %>% as_tibble() # center each predictor
# run_model ---------------------------------------------------------------
library(cmdstanr)
library(rstan)
library(tidyverse)
data <- list(Y=response_variable, X_gamma=predictors, r=sqrt(2)/4, N=length(d$rating),
p_gamma=ncol(predictors), Jeffreys_alpha=1e-5, Jeffreys_beta=1e-5)
# cmdstan : for savage-dickey method
model1 <- cmdstan_model(paste0(getwd(),"/model/model1.stan"))
# to compute the Bayes factor just once with the Savage-Dickey method, simply do the following
fit_c <- model1$sample(
data = data,
parallel_chains = 4,
chains = 4,
iter_warmup = 1000,
iter_sampling = 4000,
refresh = 0
)
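For a one-off run like this, the estimates can be read straight off the fit object:

# posterior summaries of the two generated quantities
fit_c$summary(variables = c("BF_01", "BF_10"))

The brute-force loop below goes through rstan::read_stan_csv() instead, so the rest of the post-processing can stay in the rstan workflow.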
However, in this example the Bayes factor obtained varies considerably from run to run, so we simulate by brute force: 100 runs for each model.
# BruteForce Regression(Savage-Dickey method) ---------------------------------------------------
# setup
subsetfun <- function(kosuu){
  # enumerate all 2^kosuu TRUE/FALSE combinations, one predictor subset per row
  XX <- matrix(,2^kosuu,kosuu)
for(i in 1:kosuu){
CCC <- t(rbind(rep(F,2^(kosuu-i)),rep(T,2^(kosuu-i))))
XX[,i] <- rep(CCC,2^(i-1))
}
return(XX)
}
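# example: subsetfun(2) enumerates all 2^2 = 4 predictor subsets, one model per row;
# the first (all-FALSE) row is the null model, which is why the loops below start at i = 2
subsetfun(2)
#       [,1]  [,2]
# [1,] FALSE FALSE
# [2,] FALSE  TRUE
# [3,]  TRUE FALSE
# [4,]  TRUE  TRUE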
# settings
trials <- 100
subset <- subsetfun(ncol(predictors))
fit_BF <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
fit_BF_name <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
progress_bar_j <- txtProgressBar(min=1, max=trials, style=3)
progress_bar_i <- txtProgressBar(min=2, max=nrow(subset), style=1)
# BruteForce
for(i in 2:nrow(subset)){
setTxtProgressBar(pb=progress_bar_i, value=i)
data <- list(Y=response_variable, X_gamma=predictors[,subset[i,]], r=sqrt(2)/4, N=length(d$rating),
p_gamma= ncol(predictors[,subset[i,]]), Jeffreys_alpha=1e-5, Jeffreys_beta=1e-5)
for(j in 1:trials){
setTxtProgressBar(pb=progress_bar_j, value=j)
fit_c <- model1$sample(
data = data,
chains = 4,
iter_warmup = 1000,
iter_sampling = 4000,
refresh = 0
)
stanfit1 <- rstan::read_stan_csv(fit_c$output_files())
res1 <- as.matrix(stanfit1) %>% as.data.frame() %>% select(starts_with("BF")) %>%
summarise_all(list(mean = mean)) %>% summarise_all(list(round),digits=8) %>% pivot_longer(everything())
    # keep the reciprocal of the smaller (numerically more stable) of the two
    # posterior-mean BFs, and record which of BF_01/BF_10 it corresponds to
    fit_BF[j,i] <- 1/min(res1$value)
    fit_BF_name[j,i] <- res1$name[which.max(res1$value)]
rm(fit_c)
}
save(fit_BF, fit_BF_name, file = paste0(getwd(),"/data/Data_model1.RData"))
}
# comparing plot(Savage-Dickey method VS Gaussian Quadrature) ----------------------------------------------------------
fit_BF_rate <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
for(i in 2:nrow(subset)){
  ## stop if the BF picked from the generated quantities is not uniformly BF_01 (or BF_10) across runs
  if(nrow(unique(fit_BF_name)) != 1)
    stop("the BF picked from the generated quantities is not uniformly BF_01 (or BF_10)")
  ## ID for pulling the matching result out of regressionBF
id <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
  ## if the MCMC result is BF_01, convert it to BF_10
if(unique(fit_BF_name[1,i]) == "BF_01_mean")
fit_BF_rate[,i] <- 1 / fit_BF[,i]
else
fit_BF_rate[,i] <- fit_BF[,i]
  ## evaluate the MCMC BF (generated quantities) against the regressionBF result
fit_BF_rate[,i] <- fit_BF_rate[,i] /
exp(attributes(bf[id])$bayesFactor[["bf"]])
}
fit_BF_rate <- fit_BF_rate %>% as_tibble()
ID <- c()
for(i in 2:nrow(subset)){
id <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
ID[i-1] <- exp(attributes(bf[id])$bayesFactor[["bf"]])
names(ID)[i-1] <- id
}
for(i in 1:nrow(subset)){
colnames(fit_BF_rate)[i] <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
}
fit_BF_rate %>% select(-"") %>% summarise_all(list(median)) %>% summarise_all(list(round),digits=2) %>% pivot_longer(everything()) %>%
mutate(BF = ID[name]) %>% mutate_at(vars(BF), round, digits=3) %>% mutate_at(vars(BF),as.factor) -> fit_BF_median
fit_BF_rate %>% select(-"") %>% pivot_longer(cols = everything(),names_to = "name", values_to = "fit_BF") %>%
mutate(BF = ID[name]) %>% mutate_at(vars(BF), round, digits=3) %>% mutate_at(vars(BF),as.factor) -> res
library(RColorBrewer)
mycol <- c(rep(c(brewer.pal(12,"Set3"),brewer.pal(8,"Set2"),brewer.pal(9,"Set1")),2),brewer.pal(6,"Set3"))
p <- ggplot(data=res, aes(x=BF,y=fit_BF, fill = name)) + theme_light(base_size=11) + geom_boxplot() +
geom_text(data=fit_BF_median, aes(x=BF, y=value, label=value), nudge_y = 0.08, color="grey20", size=3.4) +
theme(legend.position = "none") + scale_y_log10(limits = c(1e-02, 1e+02)) + theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
scale_fill_manual(values = mycol) + xlab("BF10 by Gaussian quadrature") + ylab("BF10 by Savage-Dickey method")
The comparison against BayesFactor::regressionBF() is shown below.

The estimates do not stray far from the BayesFactor::regressionBF() results, but from around the point where the BF exceeds 1000, the run-to-run variability of the MCMC estimates grows. There is also an overall tendency for the estimated BF to underrate the null model as the BF gets larger.
Bridge sampling method
Bridge sampling is another MCMC-based estimation method. It apparently estimates a model's free energy directly from the MCMC output, but I haven't studied the details yet!
We estimate the free energy of each model and of the null model, then take the ratio to estimate the Bayes factor.
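Concretely, writing the log marginal likelihood ("free energy") of model $M$ that bridgesampling::bridge_sampler() estimates as $\log p(\boldsymbol{Y} \mid M)$, the Bayes factor against the null model $M_N$ is

$$ \mathrm{BF}_{\gamma 0} = \cfrac{p\left(\boldsymbol{Y} \mid M_{\gamma}\right)}{p\left(\boldsymbol{Y} \mid M_{N}\right)} = \exp\Big( \log p\left(\boldsymbol{Y} \mid M_{\gamma}\right) - \log p\left(\boldsymbol{Y} \mid M_{N}\right) \Big) $$

which is the exp(free_energy_gamma - free_energy_null) line in the R code further down.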
The Stan code for obtaining the MCMC samples used to estimate each model's free energy is as follows.
data {
int<lower=0> N;
int<lower=0> p_gamma;
// N : sample size, p_gamma : number of predictors
vector[N] Y; // response variable
matrix[N, p_gamma] X_gamma; // predictors
real Jeffreys_alpha;
real Jeffreys_beta;
real r;
//r : scale of the Cauchy
//Jeffreys_alpha : shape of the gamma prior on sigma^2 (sufficiently small)
//Jeffreys_beta : rate of the gamma prior on sigma^2 (sufficiently small)
//(small values make this gamma prior approximate Jeffreys' prior)
}
parameters {
vector[p_gamma] beta_gamma;
real<lower=0> sigma;
real<lower=0> g;
real alpha;
// sigma : standard deviation of Y
// g : variance of the standardized slope
// alpha : intercept
}
transformed parameters{
vector[N] mu;
mu = X_gamma * beta_gamma;
}
model {
//model
target += normal_lpdf(Y - rep_vector(alpha, N) | mu, sigma);
// prior
target += multi_normal_lpdf( beta_gamma | rep_vector(0, p_gamma), N * g * sigma^2 * inverse(crossprod(X_gamma)));
target += inv_gamma_lpdf(g | 0.5, r*0.5 );
target += gamma_lpdf(sigma^2 | Jeffreys_alpha, Jeffreys_beta);
target += normal_lpdf(alpha | 0, 100);
}
The Stan code for obtaining the MCMC samples used to estimate the null model's free energy is as follows.
data {
int<lower=0> N;
vector[N] Y; // response variable
real Jeffreys_alpha;
real Jeffreys_beta;
//Jeffreys_alpha : shape of the gamma prior on sigma^2 (sufficiently small)
//Jeffreys_beta : rate of the gamma prior on sigma^2 (sufficiently small)
//(small values make this gamma prior approximate Jeffreys' prior)
}
parameters {
real<lower=0> sigma;
real alpha;
// sigma : standard deviation of Y
// alpha : intercept
}
model {
//model
target += normal_lpdf(Y - rep_vector(alpha, N) | 0, sigma);
// prior
target += gamma_lpdf(sigma^2 | Jeffreys_alpha, Jeffreys_beta);
target += normal_lpdf(alpha | 0, 100);
}
We run this one by brute force as well, this time with 11 runs per model.
# BruteForce Regression(bridge-sampling) -----------------------------
# rstan : for bridge-sampling
model2 <- stan_model(paste0(getwd(),"/model/model2.stan"))
model3 <- stan_model(paste0(getwd(),"/model/model3.stan"))
library(bridgesampling)
trials <- 11
subset <- subsetfun(ncol(predictors))
free_energy_null <- rep(NA, times=trials)
free_energy_gamma <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
BF_bridgesampler <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
progress_bar_j <- txtProgressBar(min=1, max=trials, style=3)
progress_bar_i <- txtProgressBar(min=2, max=nrow(subset), style=1)
# the null model needs only Y, N, and the Jeffreys prior parameters
data_null <- list(Y=response_variable, N=length(d$rating),
                  Jeffreys_alpha=1e-5, Jeffreys_beta=1e-5)
for(j in 1:trials){
  fit_b <- sampling(
    model3,
    data = data_null,
    chains = 4,
    iter = 10000,
    warmup = 1000
  )
free_energy_null[j] <- bridge_sampler(fit_b, method="warp3")$logml
rm(fit_b)
save(free_energy_null, file = paste0(getwd(),"/data/free_energy_null.RData"))
}
for(i in 2:nrow(subset)){
setTxtProgressBar(pb=progress_bar_i, value=i)
data <- list(Y=response_variable, X_gamma=predictors[,subset[i,]], r=sqrt(2)/4, N=length(d$rating),
p_gamma= ncol(predictors[,subset[i,]]), Jeffreys_alpha=1e-5, Jeffreys_beta=1e-5)
for(j in 1:trials){
setTxtProgressBar(pb=progress_bar_j, value=j)
fit_b <- sampling(
model2,
data = data,
chains = 4,
iter = 10000,
warmup = 1000
)
free_energy_gamma[j,i] <- bridge_sampler(fit_b, method="warp3")$logml
BF_bridgesampler[j,i] <- exp(free_energy_gamma[j,i] - free_energy_null[j])
rm(fit_b)
save(free_energy_gamma,BF_bridgesampler, file = paste0(getwd(),"/data/BF_bridgesampler.RData"))
}
}
# comparing plot(Bridge-sampling method VS Gaussian Quadrature) ----------------------------------------------------------
fit_BF_rate <- matrix(nrow=trials, ncol=nrow(subset), data=NA)
for(i in 2:nrow(subset)){
  ## ID for pulling the matching result out of regressionBF
id <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
fit_BF_rate[,i] <- BF_bridgesampler[,i] /
exp(attributes(bf[id])$bayesFactor[["bf"]])
}
fit_BF_rate <- fit_BF_rate %>% as_tibble()
ID <- c()
for(i in 2:nrow(subset)){
id <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
ID[i-1] <- exp(attributes(bf[id])$bayesFactor[["bf"]])
names(ID)[i-1] <- id
}
for(i in 1:nrow(subset)){
colnames(fit_BF_rate)[i] <- paste(attributes(predictors)$names[subset[i,]], collapse = " + ")
}
fit_BF_rate %>% select(-"") %>% summarise_all(list(median)) %>% summarise_all(list(round),digits=2) %>% pivot_longer(everything()) %>%
mutate(BF = ID[name]) %>% mutate_at(vars(BF), round, digits=3) %>% mutate_at(vars(BF),as.factor) -> fit_BF_median
fit_BF_rate %>% select(-"") %>% pivot_longer(cols = everything(),names_to = "name", values_to = "fit_BF") %>%
mutate(BF = ID[name]) %>% mutate_at(vars(BF), round, digits=3) %>% mutate_at(vars(BF),as.factor) -> res
library(RColorBrewer)
mycol <- c(rep(c(brewer.pal(12,"Set3"),brewer.pal(8,"Set2"),brewer.pal(9,"Set1")),2),brewer.pal(6,"Set3"))
p <- ggplot(data=res, aes(x=BF,y=fit_BF, fill = name)) + theme_light(base_size=11) + geom_boxplot() +
geom_text(data=fit_BF_median, aes(x=BF, y=value, label=value), nudge_y = 0.08, color="grey20", size=3.4) +
theme(legend.position = "none") + scale_y_log10(limits = c(1e-02, 1e+02)) + theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
scale_fill_manual(values = mycol) + xlab("BF10 by Gaussian quadrature") + ylab("BF10 by BridgeSampling")
The comparison against BayesFactor::regressionBF() is shown below.

These results, too, do not stray far from the BayesFactor::regressionBF() results. The overall tendency for the estimated BF to underrate the null model as the BF gets larger is the same as with the Savage-Dickey method.
Presumably, because equation $(2)$ (Jeffreys' prior) is only represented approximately, via a zero-mean normal with large variance on $\alpha$ and a gamma distribution on $\sigma^2$, the subtle difference from the intended model is carrying through to the results.
Also, whereas with the Savage-Dickey method the run-to-run scatter grew large once the BF passed about 1000, and the convergence of the results was frankly garbage, with bridge sampling the run-to-run scatter stays very small no matter what value the BF takes. It really drives home what an excellent method bridge sampling is... and sadly, all the effort I put into deriving the conditional posterior of $\boldsymbol{\beta}_\boldsymbol{\gamma}$ went unrewarded.
Summary
In this article we tried out several methods of computing Bayes factors on a linear regression model and compared their results. The takeaways:
・ The approximate representation of Jeffreys' prior has a small effect on BFs estimated from MCMC output (probably)
・ The estimation precision of the Savage-Dickey method is pretty sloppy
・ Bridge sampling is excellent
All of which leaves me wanting to get much better acquainted with bridge sampling.