Lasso の Post-Selection Inference

Post-Selection Inferenceの切断分布を求める処理を書いてみました。

J. Lee and Dennis L. Sun and Yuekai Sun and Jonathan E. Taylor, “Exact post-selection inference, with application to the lasso”, The Annals of Statistics, volume 44,
number=3, pages 907-927 (2016}

# X, yを中心化する関数
cent=function(z){
    if(is.matrix(z)){
        p=ncol(z)
        for(j in 1:p)z[,j]=z[,j]-mean(z[,j])
    }
    else z=z-mean(z)
  return(z)
}

# 線形回帰のLassoの係数を求める関数
lasso=function(X,y,lambda){
## X,yは、中心化されていると仮定 
    p=ncol(X)
    out=NULL
    for(j in 1:p){
        SD=sd(X[,j])
        X[,j]=X[,j]/SD
        beta=sum(X[,j]*y)
        if(beta>lambda)out=c(out,(beta-lambda)/SD)
        else if(beta< -lambda)out=c(out,(beta+lambda)/SD)
        else out=c(out,0)
    }
    return(out)
}

# 多面体で表現される区間(モデルと符号で条件付)
intervals=function(X,y,lambda,M,s,k){
    n=nrow(X)
    p=ncol(X)
    m=length(M)
    P=X[,M]%*%solve(t(X[,M])%*%X[,M])%*%t(X[,M])
    XX=X[,M]%*%solve(t(X[,M])%*%X[,M])
    A=rbind(1/lambda*t(X[,-M])%*%(diag(n)-P),
        -1/lambda*t(X[,-M])%*%(diag(n)-P),
        -diag(s)%*%solve(t(X[,M])%*%X[,M])%*%t(X[,M])
        )
    b=c(rep(1,p-m)-as.vector(t(X[,-M])%*%XX%*%s),
        as.vector(rep(1,p-m)+t(X[,-M])%*%XX%*%s),
        -as.vector(lambda*diag(s)%*%solve(t(X[,M])%*%X[,M])%*%s)
        )
    eta=(solve(t(X)%*%X)%*%(t(X)))[k,]
    cc=eta/as.vector(t(eta)%*%eta)
    z=(diag(n)-cc%*%t(eta))%*%as.matrix(y)
    Ac=as.vector(A%*%cc)
    Az=as.vector(A%*%z)
    nu.max=0
    nu.min=0
    nu.zero=0
    for(j in 1:p){
      if(Ac[j]>0)nu.max=max(nu.max,(b[j]-Az[j])/Ac[j])
      else if(Ac[j]<0)nu.min=min(nu.min,(b[j]-Az[j])/Ac[j])
      else nu.zero=min(nu.zero,b[j]-Az[j])
    }
    return(c(as.vector(t(eta)%*%as.matrix(y)),nu.min, nu.max,nu.zero))
}

# 10進数を2進数に変換する関数
binary=function(i,m){
  out=NULL
  for(j in 1:m){
    out=c(2*(i%%2)-1,out)
    i=i%/%2
  }
  return(out)
}

# モデルと符号に関する多面体の区間を、符号を動かして、合併させている。
bind.intervals=function(u,v){
  p=length(u); q=length(v)
  u=c(u,Inf); v=c(v,Inf)
  u.state=1; v.state=1
  w=NULL
  i=1; j=1
  while(i<=p||j<=q){
    if(u[i]<v[j]){
      if(v.state==1)w=c(w,u[i])
      i=i+1
      u.state=-u.state
    }
    else if(u[i]>v[j]){
      if(u.state==1)w=c(w,v[j])
      j=j+1
      v.state=-v.state
    }
    else {
      if(i!=p+1&&u.state==v.state)w=c(w,v[j])
      i=i+1; j=j+1
      u.state=-u.state; v.state=-v.state
    }
  }
  return(w)
}
bind.intervals(u,v)

##  実行例
# データ生成
n=100
p=5
X=matrix(rnorm(n*p),n,p)
y=X[,1]*2-X[,2]*3+rnorm(n)*0.4
# 中心化して、Lassoの係数を求め、アクティブ集合とそれぞれの符号を求める
X=cent(X)
y=cent(y)
lambda=40
beta=lasso(X,y,lambda)
M=NULL
ss=NULL
for(j in 1:p){
  if(beta[j]>0){M=c(M,j); ss=c(ss,1)}
  else if(beta[j]<0){M=c(M,j); ss=c(ss,-1)}
}
# 多面体の区間を求める
m=length(M)
L=2^m
print(intervals(X,y,lambda,M,ss,1)[1])
print(M)
print(ss)
S=NULL
for(i in 1:L){
  s=binary(i,m)
  u=intervals(X,y,lambda,M,s,1)
  print(s)
  print(u[2:3])
  S=bind.intervals(S,u[2:3])
}
S
pnorm(-1)