正則化(Ridge回帰とLasso回帰)について

はじめに

今回は、回帰分析において過学習を防いだり、多重共線性に対応したりするために使われるRidge回帰や、変数の数が標本数より多いような時に変数選択の方法として使われるLasso回帰について、理論を整理しようと思います。

これら2つの手法ともには正則化という手法で説明されるものです。まずは正則化の観点からRidge回帰とLasso回帰を理解し、さらにこれらの手法が確率モデルでも説明可能であることを示そうと思います。

確率モデルで説明可能な手法であるということは、データセットとモデルから尤度を計算しベイズ推定可能であるということなので、ゆくゆくはRstanでこれら2つの手法を実装してみたいと思っています。

本記事の構成は以下の通りです。

参考にした本や記事は以下のとおりです。

線形回帰の導入

以下の線形回帰を考えます。

(1)y^=w0+w1ϕ1(x)+,wHϕH(x)

ここでϕ(x)xの関数です。

これを行列で表現すると、

(2)(y^1y^2y^N)=(ϕ0(x1)ϕ1(x1)ϕH(x1)ϕ0(x2)ϕ1(x2)ϕH(x2)ϕ0(xN)ϕ1(xN)ϕH(xN))(w0w1wH)

ただし、ϕ0(xn)=1です。

これを行列を使って以下のように表記します。

(3)y^=Φw

一般にΦ計画行列と呼びます。またϕ0(x),,ϕH(x)基底関数ϕ(x)=(ϕ0(x),ϕ1(x),,ϕH(x))Tx特徴ベクトルと呼びます。

memo

基底関数を、ベクトルxをスカラーに変換する解析的な関数であると想定していますが、 通常の重回帰モデルもϕ(x)=xとすれば(3)式で説明できます!その場合基底関数は各変数それぞれを出力する関数になります。

このとき、以下の公式が得られます。

◆線形回帰モデルの解 線形回帰モデルy^=Φwにおけるwの最小二乗解は (4)w=(ΦTΦ)1ΦTy


◆証明 データ全体の誤差(損失関数)は、 (5)E=n=1N(ynϕ(xn)w)2=(yΦw)T(yΦw)=yTy2wT(ΦTy)+wTΦΦw これをwで微分すると、 (6)Ew=2ΦTy+2ΦTΦw よって、Eを最小化するwは、Ew=0より下記の通り得られる。 w=(ΦTΦ)1ΦTy

正則化

線形回帰モデルの係数パラメータは(4)式で計算できることが分かりましたが、右辺にはΦTΦの逆行列が存在するため、ΦTΦが正則でない場合、係数パラメータの解を得ることが出来ません。

memo

n次正方行列Aについて、 AB=BA=I となるn次正方行列Bが存在するとき、Aは正則であるという。

では、ΦTΦが正則でないのは具体的にどんな場合かとうと、ϕi(x)=ϕj(x)、もしくはϕi(x)=αϕj(x)  (αR)となるようなϕ(x)の組が存在する場合が典型的です。このような場合、ϕi(x)ϕj(x)の相関は1または-1になっています。またϕ(xi)αϕ(xj)だったりする場合では、ΦTΦ1の要素が大きくなり、結果的にwの絶対値が大きくなってしまい、過学習が生じるようです。このような場合では、ϕi(x)ϕj(x)はかなり強い相関関係になっています1

memo

ϕi(x)=αϕj(x)  (αR)となるようなϕ(x)の組が存在する場合、 rank(ΦTΦ)<H と行列ΦTΦはランク落ちしており、ランク落ちした行列は逆行列が存在しない。

係数パラメータの絶対値が大きくなるのを避けるための工夫が正則化です。正則化では、(5)式で得られたデータ全体の誤差に、wの絶対値が大きくなることによるペナルティを課すように設定します。 このペナルティの設定方法が、Ridge回帰とLasso回帰の唯一の違いです。

◆Ridge回帰とLasso回帰の損失関数 Ridge回帰における損失関数は、 (7)E=(yΦw)T(yΦw)+α||w||22 Lasso回帰における損失関数は、 (8)E=(yΦw)T(yΦw)+α||w||11 ただし、α0


memo

||x||pはベクトルのpノルムといい ||x||p=|x1|p+,|xn|pp を意味する。定義より||x||11xのマンハッタン距離を、||x||22xのユークリッド距離を意味する。

損失関数を上記のように設定することがΦTΦが正則でない場合に有効である仕組みを把握するために、Ridge回帰の場合についてみていきます。

(7)式をwで微分すると、

(9)Ew=2ΦTy+2ΦTΦw+2αw

よって、Eを最小化するwは、Ew=0より下記のとおり得られます。

(10)w=(ΦTΦ+αI)1ΦTy

リッジ回帰は、もともとは(7)式のように損失関数にwの絶対値を小さくする為の項を設定したものでしたが、結果的にΦTΦの対角要素に微小量αを足すことで、逆行列計算の対象が正則行列であることを確実にし、計算を安定化させていることが分かります。

memo

任意の行列ARm×n、ベクトルxRm0について、 xT(AAT)x=(ATx)T(ATx)=||ATx||220 だから、AATは半正定値行列。 また xT(αI)x=α||x||22>0 だから、αIは正定値行列。このとき xT(AAT+αI)x=||ATx||22+α||x||22>0 よってATA+αIは正定値行列であり、正則行列。

Lasso回帰については同様の方法では正則化との関係が見えてきませんでした。ただ係数パラメータの多くが0になる仕組みについてはこちらの記事に詳しかったです。

確率モデルでRidge回帰を捉える

前節で説明したRidge回帰は確率分布を用いた以下のモデルでも説明できます。

◆Ridge回帰の確率モデル (7)のRidge回帰と等価の確率モデルは、 (11){p(y|w,x)=Normal(y|wTϕ(x),σ2)p(w)=Normal(w|0,ρ1IH) ただし、α=ρσ2


◆証明 ywの同時確率密度について以下の通り仮定する。 p(y,w|x)=p(y|w,x)p(w) これをwの関数F(w)とみたとき、独立同分布のデータセットyX=(x1,,xN)Tがすべて観測された時のF(w)を最大化するwが最尤推定解である。 arg maxwRH+1logF(w)=arg maxwRH+1logn=1Np(yn|w,xn)p(w)=arg maxwRH+1[n=1NlogNormal(yn|wTϕ(x),σ2)+logNormal(w|0,ρ1IH)]=arg maxwRH+1[logMultiNormal(y|Φw,σ2IN)+logNormal(w|0,ρ1IH)]=arg maxwRH+1[12σ2(yΦw)T(yΦw)ρ2wTw+C]=arg minwRH+1[(yΦw)T(yΦw)+ρσ2||w||22]=arg minwRH+1[(yΦw)T(yΦw)+α||w||22] これは(7)式のRidge回帰における損失関数Ewについての最小化である。

確率モデルでLasso回帰を捉える

Lasso回帰は確率分布を用いた以下のモデルでも説明できます。

◆Lasso回帰の確率モデル (8)のLasso回帰と等価の確率モデルは、 (12){p(y|w,x)=Normal(y|wTϕ(x),σ2)p(w)=h=1HLaplace(wh|0,b) ただし、α=2σ2b


◆証明 ywの同時確率密度について以下の通り仮定する。 p(y,w|x)=p(y|w,x)p(w) これをwの関数F(w)とみたとき、独立同分布のデータセットyX=(x1,,xN)Tがすべて観測された時のF(w)を最大化するwが最尤推定解である。 arg maxwRH+1logF(w)=arg maxwRH+1logn=1Np(yn|w,xn)p(w)=arg maxwRH+1[n=1NlogNormal(yn|wTϕ(x),σ2)+h=1HlogLaplace(wh|0,b)]=arg maxwRH+1[logMultiNormal(y|Φw,σ2IN)+h=1HlogLaplace(wh|0,b)]=arg maxwRH+1[12σ2(yΦw)T(yΦw)1bh=1H|wh|+C]=arg minwRH+1[(yΦw)T(yΦw)+2σ2b||w||11]=arg minwRH+1[(yΦw)T(yΦw)+α||w||11] これは(8)式のLasso回帰における損失関数Ewについての最小化である。

memo

確率密度p(x)p(x)=12bexp[|xμ|b] である確率分布をラプラス分布といい、確率変数Xがラプラス分布に従うことを XLaplace(μ,b) と表す。

蛇足

今回の記事はHugoのshortcode機能を駆使して文章に枠をもたせてみました。今までの記事よりかなり見やすくなっていませんか??しばらくはこのスタイルで行こうと思います。

最後に今回の記事のイメージソングをあげておきます。ブログ名の由来になった曲です。


  1. 回帰分析で生じるこのような問題は多重共線性と呼ばれています。 ↩︎

コメントを書く


※ コメントは承認されると表示されます

承認されたコメント一覧