Post-Selection Inferenceの切断分布を求める処理を書いてみました。
J. Lee and Dennis L. Sun and Yuekai Sun and Jonathan E. Taylor, “Exact post-selection inference, with application to the lasso”, The Annals of Statistics, volume 44,
number=3, pages 907-927 (2016}
# X, yを中心化する関数
cent=function(z){
if(is.matrix(z)){
p=ncol(z)
for(j in 1:p)z[,j]=z[,j]-mean(z[,j])
}
else z=z-mean(z)
return(z)
}
# 線形回帰のLassoの係数を求める関数
lasso=function(X,y,lambda){
## X,yは、中心化されていると仮定
p=ncol(X)
out=NULL
for(j in 1:p){
SD=sd(X[,j])
X[,j]=X[,j]/SD
beta=sum(X[,j]*y)
if(beta>lambda)out=c(out,(beta-lambda)/SD)
else if(beta< -lambda)out=c(out,(beta+lambda)/SD)
else out=c(out,0)
}
return(out)
}
# 多面体で表現される区間(モデルと符号で条件付)
intervals=function(X,y,lambda,M,s,k){
n=nrow(X)
p=ncol(X)
m=length(M)
P=X[,M]%*%solve(t(X[,M])%*%X[,M])%*%t(X[,M])
XX=X[,M]%*%solve(t(X[,M])%*%X[,M])
A=rbind(1/lambda*t(X[,-M])%*%(diag(n)-P),
-1/lambda*t(X[,-M])%*%(diag(n)-P),
-diag(s)%*%solve(t(X[,M])%*%X[,M])%*%t(X[,M])
)
b=c(rep(1,p-m)-as.vector(t(X[,-M])%*%XX%*%s),
as.vector(rep(1,p-m)+t(X[,-M])%*%XX%*%s),
-as.vector(lambda*diag(s)%*%solve(t(X[,M])%*%X[,M])%*%s)
)
eta=(solve(t(X)%*%X)%*%(t(X)))[k,]
cc=eta/as.vector(t(eta)%*%eta)
z=(diag(n)-cc%*%t(eta))%*%as.matrix(y)
Ac=as.vector(A%*%cc)
Az=as.vector(A%*%z)
nu.max=0
nu.min=0
nu.zero=0
for(j in 1:p){
if(Ac[j]>0)nu.max=max(nu.max,(b[j]-Az[j])/Ac[j])
else if(Ac[j]<0)nu.min=min(nu.min,(b[j]-Az[j])/Ac[j])
else nu.zero=min(nu.zero,b[j]-Az[j])
}
return(c(as.vector(t(eta)%*%as.matrix(y)),nu.min, nu.max,nu.zero))
}
# 10進数を2進数に変換する関数
binary=function(i,m){
out=NULL
for(j in 1:m){
out=c(2*(i%%2)-1,out)
i=i%/%2
}
return(out)
}
# モデルと符号に関する多面体の区間を、符号を動かして、合併させている。
bind.intervals=function(u,v){
p=length(u); q=length(v)
u=c(u,Inf); v=c(v,Inf)
u.state=1; v.state=1
w=NULL
i=1; j=1
while(i<=p||j<=q){
if(u[i]<v[j]){
if(v.state==1)w=c(w,u[i])
i=i+1
u.state=-u.state
}
else if(u[i]>v[j]){
if(u.state==1)w=c(w,v[j])
j=j+1
v.state=-v.state
}
else {
if(i!=p+1&&u.state==v.state)w=c(w,v[j])
i=i+1; j=j+1
u.state=-u.state; v.state=-v.state
}
}
return(w)
}
bind.intervals(u,v)
## 実行例
# データ生成
n=100
p=5
X=matrix(rnorm(n*p),n,p)
y=X[,1]*2-X[,2]*3+rnorm(n)*0.4
# 中心化して、Lassoの係数を求め、アクティブ集合とそれぞれの符号を求める
X=cent(X)
y=cent(y)
lambda=40
beta=lasso(X,y,lambda)
M=NULL
ss=NULL
for(j in 1:p){
if(beta[j]>0){M=c(M,j); ss=c(ss,1)}
else if(beta[j]<0){M=c(M,j); ss=c(ss,-1)}
}
# 多面体の区間を求める
m=length(M)
L=2^m
print(intervals(X,y,lambda,M,ss,1)[1])
print(M)
print(ss)
S=NULL
for(i in 1:L){
s=binary(i,m)
u=intervals(X,y,lambda,M,s,1)
print(s)
print(u[2:3])
S=bind.intervals(S,u[2:3])
}
S
pnorm(-1)