第2章 一般化線形回帰


第1章より

## Soft-thresholding operator, applied elementwise:
##   sign(x) * max(|x| - lambda, 0)
## i.e. each element of x is moved toward zero by lambda, and elements whose
## absolute value does not exceed lambda become zero.
##   lambda: threshold
##   x:      numeric scalar or vector
soft.th <- function(lambda, x) {
  shrunk <- abs(x) - lambda
  sign(x) * pmax(shrunk, 0)
}
## Lasso for linear regression, solved by cyclic coordinate descent.
##   X:      n x p design matrix without an intercept column
##   y:      response of length n
##   lambda: L1 penalty weight (default 0, i.e. no shrinkage)
##   beta:   starting values for the p coefficients (default all zero)
## X is centred and standardized and y is centred by centralize() first; the
## coordinate updates run on that scale until no coefficient moves by more
## than 0.001 in a full sweep, and the result is mapped back afterwards.
## Returns list(beta = coefficients on the original scale of X,
##              beta.0 = intercept).
linear.lasso <- function(X, y, lambda = 0, beta = rep(0, ncol(X))) {
  n <- nrow(X)
  p <- ncol(X)
  res <- centralize(X, y)   ## centring / standardization
  X <- res$X
  y <- res$y
  eps <- 1
  beta.old <- beta
  while (eps > 0.001) {     ## sweep until the coefficients stop moving
    for (j in seq_len(p)) {
      ## Partial residual with variable j left out. drop = FALSE keeps the
      ## matrix shape even when no column remains, so a single-predictor
      ## problem (p == 1) works instead of failing on a collapsed subset.
      r <- y - X[, -j, drop = FALSE] %*% beta[-j]
      beta[j] <- soft.th(lambda, sum(r * X[, j]) / n) / (sum(X[, j] * X[, j]) / n)
    }
    eps <- max(abs(beta - beta.old))
    beta.old <- beta
  }
  beta <- beta / res$X.sd   ## undo the column standardization
  beta.0 <- res$y.bar - sum(res$X.bar * beta)
  return(list(beta = beta, beta.0 = beta.0))
}
## Centre (and optionally standardize) the columns of X and centre y.
##   X:           n x p matrix, or anything as.matrix() can convert
##   y:           response; a vector, or a matrix with one column per response
##   standardize: if TRUE, each centred column of X is also divided by its
##                sample standard deviation
## Returns a list with the transformed X and y together with the column means
## of X (X.bar), the column standard deviations of X (X.sd) and the mean(s)
## of y (y.bar), so callers can map coefficients back to the original scale.
centralize <- function(X, y, standardize = TRUE) {
  X <- as.matrix(X)
  p <- ncol(X)
  X.bar <- array(dim = p)          ## column means of X
  X.sd <- array(dim = p)           ## column standard deviations of X
  for (j in seq_len(p)) {
    X.bar[j] <- mean(X[, j])
    X[, j] <- (X[, j] - X.bar[j])  ## centre column j
    X.sd[j] <- sqrt(var(X[, j]))
    if (standardize == TRUE) {
      X[, j] <- X[, j] / X.sd[j]   ## scale column j to unit standard deviation
    }
  }
  ## is.matrix() instead of class(y) == "matrix": since R 4.0 a matrix has
  ## class c("matrix", "array"), so the comparison has length 2 and cannot be
  ## used as an if() condition.
  if (is.matrix(y)) {              ## y is a matrix: centre column by column
    K <- ncol(y)
    y.bar <- array(dim = K)        ## column means of y
    for (k in seq_len(K)) {
      y.bar[k] <- mean(y[, k])
      y[, k] <- y[, k] - y.bar[k]
    }
  } else {                         ## y is a vector
    y.bar <- mean(y)
    y <- y - y.bar
  }
  return(list(X = X, y = y, X.bar = X.bar, X.sd = X.sd, y.bar = y.bar))
}

2.1 線形回帰のLassoの一般化

## Weighted linear lasso.
##   X:      n x p matrix of covariates (no intercept column)
##   y:      response of length n
##   W:      n x n weight matrix; it must admit a Cholesky factor via chol().
##           The callers in this file pass a diagonal matrix.
##   lambda: L1 penalty weight
## X and y are centred at their W-weighted means, both are multiplied by the
## Cholesky factor of W, and the transformed problem is handed to
## linear.lasso(). Returns c(intercept, coefficients).
W.linear.lasso <- function(X, y, W, lambda = 0) {
  n.col <- ncol(X)
  w.total <- sum(W)
  col.means <- array(dim = n.col)
  for (idx in 1:n.col) {
    ## weighted mean of column idx, then centre the column at it
    col.means[idx] <- sum(W %*% X[, idx]) / w.total
    X[, idx] <- X[, idx] - col.means[idx]
  }
  resp.mean <- sum(W %*% y) / w.total
  resp <- y - resp.mean
  root <- chol(W)
  fit <- linear.lasso(root %*% X, as.vector(root %*% resp), lambda)
  slopes <- fit$beta
  ## intercept on the original (uncentred) scale
  intercept <- resp.mean - sum(col.means * slopes)
  c(intercept, slopes)
}

2.2 2値のロジスティック回帰

例11

## Logistic curve exp(b0 + b x) / (1 + exp(b0 + b x)). It reads the globals
## beta.0 and beta at call time, which is how the loop below redraws the same
## function for different slopes.
f <- function(x) {
  odds <- exp(beta.0 + beta * x)
  odds / (1 + odds)
}
beta.0 <- 0
beta.seq <- c(0, 0.2, 0.5, 1, 2, 10)
m <- length(beta.seq)

## First curve sets up the axes and the title
beta <- beta.seq[1]
plot(f,
     xlim = c(-10, 10), ylim = c(0, 1),
     xlab = "x", ylab = "y",
     col = 1, main = "ロジスティック曲線")

## Remaining curves are overlaid on the same axes, one colour per slope
for (i in 2:m) {
  beta <- beta.seq[i]
  par(new = TRUE)
  plot(f,
       xlim = c(-10, 10), ylim = c(0, 1),
       xlab = "", ylab = "",
       axes = FALSE, col = i)
}
legend("topleft", legend = beta.seq, col = seq_along(beta.seq), lwd = 2, cex = 0.8)

par(new = FALSE)

例12

## Data generation: N observations of p standard-normal covariates, with a
## column of ones prepended to X for the intercept
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- rnorm(p + 1)   ## coefficients used to generate the data
## Note: array(N) is a length-1 array holding the value N; the entries
## y[1], ..., y[N] are assigned one by one in the loop below.
y <- array(N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s))   ## y[i] becomes +1 when the uniform draw exceeds prob[i]
for (i in 1:N) {
  if (runif(1) > prob[i]) {
    y[i] <- 1
  } else {
    y[i] <- -1
  }
}
beta   ## the generating coefficients; this is what the fit below should recover
## [1] -0.8251151 -0.6041347  2.8357392
## Maximum likelihood estimate by iteratively reweighted least squares:
## each pass forms weights w and a working response z from the current beta
## and solves the weighted least-squares problem for gamma.
beta <- Inf
gamma <- rnorm(p + 1)
while (sum((beta - gamma) ^ 2) > 0.001) {
  beta <- gamma
  s <- as.vector(X %*% beta)
  v <- exp(-s * y)
  u <- y * v / (1 + v)
  w <- v / (1 + v) ^ 2
  z <- s + u / w
  W <- diag(w)
  gamma <- as.vector(solve(t(X) %*% W %*% X) %*% t(X) %*% W %*% z)          ## weighted least-squares solution
  print(gamma)
}
## [1] -1.09048865 -0.01923956  2.11691931
## [1] -0.7567999 -0.5396643  2.5134395
## [1] -0.7908441 -0.6295797  2.8209931
## [1] -0.7982177 -0.6412975  2.8641928
## [1] -0.7983656 -0.6414807  2.8649043
## NOTE: beta was overwritten inside the loop, so this prints the iterate
## from before the final update (it matches the fourth printed line), not the
## generating coefficients; the final estimate is in gamma.
beta
## [1] -0.7982177 -0.6412975  2.8641928
## Lasso for binary logistic regression, fitted by repeatedly solving a
## weighted linear lasso on a working response.
##   X:      design matrix whose first column is the intercept column; only
##           columns 2..ncol(X) are passed on as covariates
##   y:      responses; the update formulas use exp(-s * y), which matches
##           the -1 / +1 coding used by the examples in this file
##   lambda: L1 penalty weight
## Starts from a random coefficient vector, prints the estimate after every
## iteration, and stops once the squared change is at most 0.01.
## Returns c(intercept, coefficients) as produced by W.linear.lasso().
logistic.lasso <- function(X, y, lambda) {
  n.coef <- ncol(X)
  previous <- Inf
  current <- rnorm(n.coef)
  while (sum((previous - current) ^ 2) > 0.01) {
    previous <- current
    eta <- as.vector(X %*% previous)          ## linear predictor
    margin <- as.vector(exp(-eta * y))
    grad <- y * margin / (1 + margin)
    weight <- margin / (1 + margin) ^ 2
    working <- eta + grad / weight            ## working response
    current <- W.linear.lasso(X[, 2:n.coef], working, diag(weight), lambda = lambda)
    print(current)
  }
  current
}

例13

## Simulated two-class data: p standard-normal covariates plus an intercept
## column, labels in {-1, +1}
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- rnorm(p + 1)
y <- array(N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s))
## y[i] is +1 when the uniform draw exceeds prob[i], otherwise -1
for (i in 1:N) {
  if (runif(1) > prob[i]) {
    y[i] <- 1
  } else {
    y[i] <- -1
  }
}
#logistic.lasso(X, y, 0)
## Fit with lambda = 0.1; each printed line is one iteration's estimate
logistic.lasso(X, y, 0.1)
## [1] -0.1179388  1.7995542 -0.7891112
## [1] -0.1239010  1.3534193  0.2157502
## [1] -0.13270237  1.52965899  0.01516135
## [1] -0.136756435  1.584675126  0.009576722
## [1] -0.136756435  1.584675126  0.009576722
## Fit again with the larger penalty lambda = 0.2
logistic.lasso(X, y, 0.2)
## [1] 0.1158749 1.3300292 0.0000000
## [1] -0.1216319  1.2138788  0.0000000
## [1] -0.1118214  1.2051842  0.0000000
## [1] -0.1118214  1.2051842  0.0000000

例14

## データ生成
## Data generation: same setup as before, but the coefficients are scaled by 10
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- 10 * rnorm(p + 1)
y <- array(N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s)) 
for (i in 1:N) {
  if (runif(1) > prob[i]) {
    y[i] <- 1
  } else {
    y[i] <- -1
  }
}

## Parameter estimation (each printed line is one iteration's estimate)
beta.est <- logistic.lasso(X, y, 0.1)
## [1]  1.001715 -1.602514  0.000000
## [1]  1.3120540 -2.3016201 -0.2229933
## [1]  1.5723713 -2.8611635 -0.3413522
## [1]  1.7273824 -3.1952329 -0.4064963
## [1]  1.7918640 -3.3358622 -0.4353427
## [1]  1.8119803 -3.3802578 -0.4451937
## Classification: draw a fresh set of labels from the same probabilities
## and compare them with the sign of the fitted linear predictor
for (i in 1:N) {
  if (runif(1) > prob[i]) {
    y[i] <- 1
  } else {
    y[i] <- -1
  }
}
z <- sign(X %*% beta.est)  ## predict +1 if the linear predictor is positive, -1 if negative
table(y, z)
##     z
## y     -1   1
##   -1 266  62
##   1   37 635

例15

library(glmnet)
## Warning: package 'glmnet' was built under R version 4.0.3
## Loading required package: Matrix
## Loaded glmnet 4.0-2
df <- read.csv("breastcancer.csv")
## The file breastcancer.csv must be in the current working directory
x <- as.matrix(df[, 1:1000])   ## columns 1-1000 are used as covariates
y <- as.vector(df[, 1001])     ## column 1001 is used as the binary response
## Cross-validation over lambda: once with the default type.measure and once
## with the misclassification rate
cv <- cv.glmnet(x, y, family = "binomial")
cv2 <- cv.glmnet(x, y, family = "binomial", type.measure = "class")
par(mfrow = c(1, 2))
plot(cv)
plot(cv2)

par(mfrow = c(1, 1))
## Refit at the fixed value lambda = 0.03 and list the nonzero coefficients.
## NOTE(review): the variable name glm shadows stats::glm in this workspace.
glm <- glmnet(x, y, lambda = 0.03, family = "binomial")
beta <- drop(glm$beta)
beta[beta != 0]
##   A.200053_at A.200740_s_at   A.200855_at   A.202011_at   A.202117_at 
##   0.253816593  -0.009995197   0.127286134   0.050447833   0.027646245 
##   A.203188_at   A.203287_at A.203347_s_at   A.206197_at A.208184_s_at 
##   0.022838365  -0.044863822  -0.258511276   0.044533711  -0.161215317 
## A.209608_s_at A.210137_s_at   A.212708_at   A.213285_at A.216515_x_at 
##  -0.191407462   0.209509915   0.033927033   0.016215819  -0.014640332 
## A.218080_x_at   A.218795_at A.218877_s_at A.219252_s_at A.219490_s_at 
##  -0.140710024   0.221924131  -0.135956993   0.374797685  -0.022592177 
## A.221562_s_at A.221740_x_at   A.221951_at B.224217_s_at   B.226831_at 
##   0.148270964   0.180185565   0.201092912  -0.036098456   0.143513787 
##   B.227423_at   B.228081_at B.229181_s_at   B.229342_at   B.232398_at 
##  -0.030834163   0.023929217  -0.201146289   0.015535533  -0.169843879 
##   B.233413_at   B.238425_at   B.242255_at 
##   0.150160625  -0.284483809  -0.285904123

2.3 多値のロジスティック回帰

## Lasso for multi-class logistic regression.
##   X:      design matrix whose first column is the intercept column; only
##           columns 2..ncol(X) are passed on as covariates
##   y:      class labels; class k is matched with y == k, so the labels are
##           expected to be 1, ..., K
##   lambda: L1 penalty weight
## Each outer pass updates the coefficient row of every class in turn with a
## weighted linear lasso (printing the new row), then subtracts the
## column-wise median across classes. The loop stops when the Frobenius norm
## of the change over a pass is at most 0.1.
## Returns the K x ncol(X) coefficient matrix, one row per class.
multi.lasso <- function(X, y, lambda) {
  X <- as.matrix(X)
  n.coef <- ncol(X)
  n.class <- length(table(y))
  coefs <- matrix(1, nrow = n.class, ncol = n.coef)
  snapshot <- matrix(0, nrow = n.class, ncol = n.coef)
  while (norm(coefs - snapshot, "F") > 0.1) {
    snapshot <- coefs
    for (k in 1:n.class) {
      ## sum of exp(linear predictor) over all classes other than k
      others <- 0
      for (h in (1:n.class)[-k]) {
        others <- others + exp(as.vector(X %*% coefs[h, ]))
      }
      eta <- as.vector(X %*% coefs[k, ])
      ratio <- exp(eta) / others
      grad <- as.numeric(y == k) - ratio / (1 + ratio)
      weight <- ratio / (1 + ratio) ^ 2
      working <- eta + grad / weight
      coefs[k, ] <- W.linear.lasso(X[, 2:n.coef], working, diag(weight), lambda = lambda)
      print(coefs[k, ])
    }
    ## centre every coefficient column at its median across the classes
    for (j in 1:n.coef) {
      coefs[, j] <- coefs[, j] - median(coefs[, j])
    }
  }
  coefs
}

例16

## Iris data: the first four columns as a plain numeric matrix, an intercept
## column in front, and labels 1, 2, 3 for consecutive blocks of 50 rows
df <- iris
x <- matrix(0, nrow = 150, ncol = 4)
for (j in seq_len(4)) {
  x[, j] <- df[[j]]
}
X <- cbind(1, x)
y <- rep(c(1, 2, 3), each = 50)
## Each printed line below is the updated coefficient row of one class
beta <- multi.lasso(X, y, 0.01)
## [1] 0.3080950 1.2921571 2.0231190 0.0000000 0.6787529
## [1]  7.2969918  1.1262008 -0.6350756  1.0893138 -0.6835181
## [1] -0.4022633  0.7318326  0.8424524  1.3606179  2.2923359
## [1]  2.4562742  0.0000000  1.3406646 -1.6910411 -0.1130906
## [1]  6.231622  0.000000 -1.150196  0.000000 -1.324055
## [1] -1.8822458 -0.8146104 -1.5921661  1.5236968  2.9584280
## [1]  1.4538467  0.0000000  2.6462788 -2.2065299 -0.5212037
## [1]  4.03562210  0.00000000 -0.04287595  0.00000000 -1.25669808
## [1] -7.584966 -1.151576 -1.639842  2.804778  4.575195
## [1]  1.1910878  0.0000000  2.8776746 -2.6542450 -0.6417311
## [1]  2.66120636  0.00000000 -0.01066202  0.00000000 -0.74616319
## [1] -13.763670  -1.464214  -2.468618   4.081814   6.832851
## [1]  1.1048739  0.0000000  3.0692005 -3.0580639 -0.7177671
## [1]  1.47419608  0.00000000  0.00000000  0.00000000 -0.08951308
## [1] -20.565411  -1.732446  -3.210548   5.308382   9.500559
## [1]  1.221106  0.000000  3.184656 -3.441061 -1.380647
## [1] 0.3896127 0.0000000 0.0000000 0.0000000 0.0000000
## [1] -27.363672  -1.904809  -3.977700   6.391882  11.737779
## [1]  2.333029  0.000000  3.190658 -3.847962 -2.057788
## [1] 0.01083126 0.00000000 0.00000000 0.00000000 0.00000000
## [1] -32.274017  -1.995927  -4.635615   7.197584  13.509639
## [1]  4.156703  0.000000  3.068626 -4.299598 -2.588838
## [1] 0.002601294 0.000000000 0.000000000 0.000000000 0.000000000
## [1] -34.511033  -2.035226  -4.965812   7.589258  14.386216
## [1]  6.325169  0.000000  2.822627 -4.825354 -2.881859
## [1] 0.0002078256 0.0000000000 0.0000000000 0.0000000000 0.0000000000
## [1] -35.052617  -2.048477  -5.052837   7.690643  14.604794
## [1]  8.783234  0.000000  2.472210 -5.420466 -2.935649
## [1] 6.881184e-06 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.112978  -2.052167  -5.061101   7.704190  14.628298
## [1] 11.602607  0.000000  2.036862 -6.175135 -2.594422
## [1] -6.608283e-05  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119306  -2.052661  -5.062040   7.705757  14.630812
## [1] 14.431959  0.000000  1.633922 -7.082344 -1.851916
## [1] -7.222944e-05  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.120009  -2.052719  -5.062148   7.705937  14.631095
## [1] 16.467117  0.000000  1.391694 -7.833493 -1.162274
## [1] -3.886799e-05  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.120030  -2.052723  -5.062156   7.705949  14.631118
## [1] 17.4916404  0.0000000  1.3275401 -8.3432360 -0.6005499
## [1] -1.202671e-05  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119990  -2.052720  -5.062154   7.705944  14.631112
## [1] 17.7362987  0.0000000  1.3730213 -8.6046329 -0.2414809
## [1] -3.149099e-06  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119972  -2.052719  -5.062153   7.705941  14.631109
## [1] 17.65919548  0.00000000  1.43419806 -8.69220015 -0.08766604
## [1] -9.567042e-07  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119967  -2.052719  -5.062152   7.705940  14.631108
## [1] 17.55654211  0.00000000  1.47185297 -8.70854021 -0.04348658
## [1] -2.486972e-07  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119966  -2.052719  -5.062152   7.705940  14.631108
## [1] 17.49822846  0.00000000  1.48944145 -8.70864338 -0.03400261
## [1] -5.091597e-08  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## [1] -35.119965  -2.052719  -5.062152   7.705940  14.631108
## Linear predictor of every observation (rows) under each class (columns)
X %*% t(beta)
##              [,1] [,2]        [,3]
##   [1,]  10.512372    0 -49.5918259
##   [2,]   9.767652    0 -46.6502061
##   [3,]  10.936404    0 -48.0226868
##   [4,]   9.045731    0 -45.7700118
##   [5,]  10.661316    0 -49.8927693
##   [6,]   8.488755    0 -46.9944990
##   [7,]  10.360028    0 -46.5961406
##   [8,]   9.492564    0 -48.1097449
##   [9,]   9.618707    0 -45.1176316
##  [10,]   9.049132    0 -47.8489381
##  [11,]   9.939396    0 -50.4494780
##  [12,]   8.621700    0 -46.9286072
##  [13,]   9.771052    0 -47.9080450
##  [14,]  12.383645    0 -49.1934676
##  [15,]  12.998822    0 -55.1009930
##  [16,]  10.975205    0 -51.6825786
##  [17,]  11.972213    0 -50.0768748
##  [18,]  10.508972    0 -48.1287152
##  [19,]   8.343211    0 -48.5672101
##  [20,]  10.084940    0 -48.8767668
##  [21,]   7.750835    0 -47.3896444
##  [22,]   9.932596    0 -46.9074409
##  [23,]  14.144774    0 -52.1540576
##  [24,]   7.591690    0 -41.8782814
##  [25,]   6.009107    0 -44.6168253
##  [26,]   8.025923    0 -45.3142901
##  [27,]   8.614899    0 -44.4129294
##  [28,]   9.641508    0 -49.0265038
##  [29,]  10.363428    0 -49.2908826
##  [30,]   8.323811    0 -45.7109049
##  [31,]   8.174867    0 -45.4099616
##  [32,]   9.485763    0 -46.0046109
##  [33,]  10.538573    0 -53.5269058
##  [34,]  11.554981    0 -53.9564199
##  [35,]   9.045731    0 -46.3858274
##  [36,]  11.807269    0 -49.4090964
##  [37,]  11.383237    0 -51.1835074
##  [38,]  10.664717    0 -51.1506082
##  [39,]  10.638516    0 -46.3944408
##  [40,]   9.492564    0 -48.3150168
##  [41,]  11.379836    0 -48.6940373
##  [42,]   9.592507    0 -41.5930954
##  [43,]  10.936404    0 -47.4068712
##  [44,]   8.757043    0 -41.9929231
##  [45,]   6.598083    0 -44.3312803
##  [46,]   9.764251    0 -44.9818235
##  [47,]   9.217476    0 -49.5692836
##  [48,]  10.065540    0 -47.0468210
##  [49,]   9.939396    0 -50.2442061
##  [50,]  10.214484    0 -48.3741236
##  [51,] -18.713786    0  -8.9864161
##  [52,] -16.975458    0  -7.8328621
##  [53,] -20.607859    0  -5.2706303
##  [54,] -13.954833    0  -8.2086697
##  [55,] -18.442099    0  -5.2426791
##  [56,] -17.564434    0  -7.2973196
##  [57,] -18.571643    0  -5.1295067
##  [58,]  -7.699638    0 -17.2667436
##  [59,] -18.286354    0  -8.8803877
##  [60,] -12.491592    0  -8.9251981
##  [61,] -10.037143    0 -13.9059667
##  [62,] -14.660753    0  -8.1058542
##  [63,] -14.093576    0 -13.1181460
##  [64,] -19.160619    0  -5.6203237
##  [65,]  -9.577711    0 -14.5336086
##  [66,] -16.250138    0 -10.1761672
##  [67,] -17.273346    0  -5.1782567
##  [68,] -14.219720    0 -14.4680844
##  [69,] -18.464899    0  -2.3601662
##  [70,] -12.779280    0 -13.1231874
##  [71,] -19.598252    0  -0.1053886
##  [72,] -13.210112    0 -11.9713769
##  [73,] -21.501524    0  -1.0017079
##  [74,] -19.302762    0  -8.0403299
##  [75,] -15.673761    0 -10.7816258
##  [76,] -16.399082    0  -9.4646801
##  [77,] -20.180427    0  -5.7804176
##  [78,] -21.634468    0  -0.6570559
##  [79,] -17.422290    0  -5.4931290
##  [80,]  -9.143478    0 -18.3801611
##  [81,] -12.057360    0 -13.1822943
##  [82,] -11.183095    0 -15.4159990
##  [83,] -12.484792    0 -13.0830508
##  [84,] -22.948765    0   1.6059760
##  [85,] -17.273346    0  -4.7677130
##  [86,] -16.680970    0  -6.5610943
##  [87,] -18.866131    0  -6.4012745
##  [88,] -17.438290    0  -6.7684688
##  [89,] -13.783088    0 -11.1868540
##  [90,] -13.656945    0  -9.2211001
##  [91,] -16.988058    0  -8.1080502
##  [92,] -18.140810    0  -6.8971328
##  [93,] -13.504600    0 -11.8062416
##  [94,]  -7.848582    0 -16.9658003
##  [95,] -15.100785    0  -8.8976144
##  [96,] -14.650552    0 -12.0846427
##  [97,] -14.802897    0 -10.1153167
##  [98,] -15.673761    0 -10.3710821
##  [99,]  -4.941501    0 -19.0321737
## [100,] -14.080977    0 -10.3796955
## [101,] -29.923482    0  18.0562115
## [102,] -22.958966    0   6.4058520
## [103,] -29.485849    0  11.3096452
## [104,] -27.011999    0   6.7569213
## [105,] -28.618384    0  13.2337932
## [106,] -35.581899    0  15.6774436
## [107,] -18.024868    0   1.7159439
## [108,] -33.108049    0  10.0983603
## [109,] -29.349504    0   9.5018825
## [110,] -30.347513    0  15.4607130
## [111,] -22.217645    0   3.9009836
## [112,] -24.700694    0   6.7154087
## [113,] -26.002391    0   8.8430850
## [114,] -22.389390    0   8.3160710
## [115,] -22.827023    0  13.2151905
## [116,] -23.969575    0  10.0367756
## [117,] -25.992190    0   5.0695684
## [118,] -35.264610    0  13.6561548
## [119,] -38.797069    0  22.7350360
## [120,] -22.819221    0   1.9033473
## [121,] -27.453032    0  12.0927921
## [122,] -21.071693    0   6.2321033
## [123,] -36.747251    0  15.7920854
## [124,] -21.213837    0   2.3751940
## [125,] -27.297287    0   9.0708992
## [126,] -30.048624    0   6.4732047
## [127,] -20.194028    0   1.3036567
## [128,] -20.767004    0   1.2670921
## [129,] -27.171144    0  11.4471969
## [130,] -28.597983    0   3.0182257
## [131,] -31.518665    0  10.3212265
## [132,] -32.645217    0   8.0076077
## [133,] -27.174544    0  12.9103076
## [134,] -22.796421    0  -0.9791656
## [135,] -27.445230    0   2.8336676
## [136,] -31.234378    0  14.5454235
## [137,] -26.287680    0  13.0045097
## [138,] -25.843246    0   4.7686250
## [139,] -19.896140    0   0.7017700
## [140,] -24.982583    0   7.3610040
## [141,] -26.734512    0  13.7020679
## [142,] -22.376790    0   7.9754436
## [143,] -22.958966    0   6.4058520
## [144,] -29.194761    0  13.8392519
## [145,] -27.310888    0  14.9233422
## [146,] -23.396599    0   9.6627965
## [147,] -22.385990    0   5.6213291
## [148,] -23.386398    0   5.6840080
## [149,] -24.542551    0  10.2054829
## [150,] -22.508733    0   3.2188237

例17

library(glmnet)
## Multi-class lasso on the iris data with glmnet
df <- iris
x <- as.matrix(df[, 1:4])
y <- as.vector(df[, 5]) 
n <- length(y)
u <- array(dim = n)
## Recode the labels: 1 for "setosa", 2 for "versicolor", 3 for anything else
for (i in 1:n) {
  if (y[i] == "setosa") {
    u[i] <- 1
  } else if (y[i] == "versicolor") {
    u[i] <- 2
  } else {
    u[i] <- 3
  }
}
u <- as.numeric(u)
## Cross-validation over lambda: once with the default type.measure and once
## with the misclassification rate
cv <- cv.glmnet(x, u, family = "multinomial")
cv2 <- cv.glmnet(x, u, family = "multinomial", type.measure = "class")
par(mfrow = c(1, 2))
plot(cv)
plot(cv2)

par(mfrow = c(1, 1))
lambda <- cv$lambda.min
## NOTE(review): this fit uses the character labels y while the
## cross-validation above used the numeric codes u; presumably the class
## order is the same in both, as the table below is read that way - confirm.
result <- glmnet(x, y, lambda = lambda, family = "multinomial")
beta <- result$beta
beta.0 <- result$a0
v <- rep(0, n)
## For every observation pick the class j with the largest linear score
## beta.0[j] + <beta[[j]], x[i, ]>
for (i in 1:n) {
  max.value <- -Inf
  for (j in 1:3) {
    value <- beta.0[j] + sum(beta[[j]] * x[i, ])
    if (value > max.value) {
      v[i] <- j
      max.value <- value
    }
  }
}
## Confusion table: recoded labels (rows) against predicted classes (columns)
table(u, v)
##    v
## u    1  2  3
##   1 50  0  0
##   2  0 48  2
##   3  0  1 49

2.4 ポアッソン回帰

## Lasso for Poisson regression with log link, fitted by repeatedly solving a
## weighted linear lasso on a working response.
##   X:      design matrix whose first column is the intercept column; the
##           remaining columns are the covariates
##   y:      observed counts
##   lambda: L1 penalty weight
## Prints the estimate after every iteration and stops once the squared
## change is at most 0.0001.
## Returns c(intercept, coefficients) as produced by W.linear.lasso().
poisson.lasso <- function(X, y, lambda) {
  ## Number of covariates, taken from X itself rather than from a variable
  ## p that happens to exist in the calling workspace
  p <- ncol(X) - 1
  beta <- rnorm(p + 1)
  gamma <- rnorm(p + 1)
  while (sum((beta - gamma) ^ 2) > 0.0001) {
    beta <- gamma
    s <- as.vector(X %*% beta)   ## linear predictor
    w <- exp(s)                  ## fitted mean, also used as the weight
    u <- y - w
    z <- s + u / w               ## working response
    W <- diag(w)
    gamma <- W.linear.lasso(X[, 2:(p + 1)], z, W, lambda)
    print(gamma)
  }
  return(gamma)
}

例18

## The sample size had been split across two lines ("n <- 10" followed by a
## stray "00" statement that only printed 0); it is rejoined here as a single
## assignment. The iteration output recorded after this block was produced
## before that repair.
n <- 1000
p <- 3
beta <- rnorm(p + 1)
X <- matrix(rnorm(n * p), ncol = p)
X <- cbind(1, X)                  ## intercept column in front
s <- as.vector(X %*% beta)
y <- rpois(n, lambda = exp(s))    ## counts with mean exp(linear predictor)
beta
## [1] -1.0674998 -0.1966005 -1.1714670  0.9481667
poisson.lasso(X, y, 0.2)
## [1] 0.5148544 0.2274658 0.0000000 0.2304827
## [1] -0.03376344  0.10647382  0.00000000  0.03611964
## [1] -0.44375007  0.07246886 -0.20635897  0.00000000
## [1] -0.6729065  0.1041351 -0.3866383  0.0000000
## [1] -0.7514708  0.1366967 -0.4397856  0.0000000
## [1] -0.7625520  0.1455526 -0.4472076  0.0000000
## [1] -0.7629401  0.1460133 -0.4474231  0.0000000
## [1] -0.7629401  0.1460133 -0.4474231  0.0000000

例19

library(glmnet)
library(MASS)
## Poisson lasso with glmnet on the birthwt data
data(birthwt)
df <- birthwt[, -1]           ## drop the first column
dy <- df[, 8]                 ## column 8 of the remaining data is the response
dx <- data.matrix(df[, -8])   ## all other columns are the covariates
## Cross-validate over lambda and show the coefficients at the minimizing value
cvfit <- cv.glmnet(x = dx, y = dy, family = "poisson", standardize = TRUE)
coef(cvfit, s = "lambda.min")
## 9 x 1 sparse Matrix of class "dgCMatrix"
##                         1
## (Intercept) -0.8603852995
## age          0.0243126476
## lwt          0.0004265947
## race         .           
## smoke        .           
## ptl          .           
## ht           .           
## ui           .           
## bwt          .

2.5 生存時間解析

library(survival)
## Load the kidney data set and print it in full
data(kidney)
kidney
##    id time status age sex disease frail
## 1   1    8      1  28   1   Other   2.3
## 2   1   16      1  28   1   Other   2.3
## 3   2   23      1  48   2      GN   1.9
## 4   2   13      0  48   2      GN   1.9
## 5   3   22      1  32   1   Other   1.2
## 6   3   28      1  32   1   Other   1.2
## 7   4  447      1  31   2   Other   0.5
## 8   4  318      1  32   2   Other   0.5
## 9   5   30      1  10   1   Other   1.5
## 10  5   12      1  10   1   Other   1.5
## 11  6   24      1  16   2   Other   1.1
## 12  6  245      1  17   2   Other   1.1
## 13  7    7      1  51   1      GN   3.0
## 14  7    9      1  51   1      GN   3.0
## 15  8  511      1  55   2      GN   0.5
## 16  8   30      1  56   2      GN   0.5
## 17  9   53      1  69   2      AN   0.7
## 18  9  196      1  69   2      AN   0.7
## 19 10   15      1  51   1      GN   0.4
## 20 10  154      1  52   1      GN   0.4
## 21 11    7      1  44   2      AN   0.6
## 22 11  333      1  44   2      AN   0.6
## 23 12  141      1  34   2   Other   1.2
## 24 12    8      0  34   2   Other   1.2
## 25 13   96      1  35   2      AN   1.4
## 26 13   38      1  35   2      AN   1.4
## 27 14  149      0  42   2      AN   0.4
## 28 14   70      0  42   2      AN   0.4
## 29 15  536      1  17   2   Other   0.4
## 30 15   25      0  17   2   Other   0.4
## 31 16   17      1  60   1      AN   1.1
## 32 16    4      0  60   1      AN   1.1
## 33 17  185      1  60   2   Other   0.8
## 34 17  177      1  60   2   Other   0.8
## 35 18  292      1  43   2   Other   0.8
## 36 18  114      1  44   2   Other   0.8
## 37 19   22      0  53   2      GN   0.5
## 38 19  159      0  53   2      GN   0.5
## 39 20   15      1  44   2   Other   1.3
## 40 20  108      0  44   2   Other   1.3
## 41 21  152      1  46   1     PKD   0.2
## 42 21  562      1  47   1     PKD   0.2
## 43 22  402      1  30   2   Other   0.6
## 44 22   24      0  30   2   Other   0.6
## 45 23   13      1  62   2      AN   1.7
## 46 23   66      1  63   2      AN   1.7
## 47 24   39      1  42   2      AN   1.0
## 48 24   46      0  43   2      AN   1.0
## 49 25   12      1  43   1      AN   0.7
## 50 25   40      1  43   1      AN   0.7
## 51 26  113      0  57   2      AN   0.5
## 52 26  201      1  58   2      AN   0.5
## 53 27  132      1  10   2      GN   1.1
## 54 27  156      1  10   2      GN   1.1
## 55 28   34      1  52   2      AN   1.8
## 56 28   30      1  52   2      AN   1.8
## 57 29    2      1  53   1      GN   1.5
## 58 29   25      1  53   1      GN   1.5
## 59 30  130      1  54   2      GN   1.5
## 60 30   26      1  54   2      GN   1.5
## 61 31   27      1  56   2      AN   1.7
## 62 31   58      1  56   2      AN   1.7
## 63 32    5      0  50   2      AN   1.3
## 64 32   43      1  51   2      AN   1.3
## 65 33  152      1  57   2     PKD   2.9
## 66 33   30      1  57   2     PKD   2.9
## 67 34  190      1  44   2      GN   0.7
## 68 34    5      0  45   2      GN   0.7
## 69 35  119      1  22   2   Other   2.2
## 70 35    8      1  22   2   Other   2.2
## 71 36   54      0  42   2   Other   0.7
## 72 36   16      0  42   2   Other   0.7
## 73 37    6      0  52   2     PKD   2.1
## 74 37   78      1  52   2     PKD   2.1
## 75 38   63      1  60   1     PKD   1.2
## 76 38    8      0  60   1     PKD   1.2
y <- kidney$time          ## observed times
delta <- kidney$status    ## status column (values 1 and 0 in the listing above)
## Survival object; entries whose status is 0 are printed with a trailing "+"
Surv(y, delta)
##  [1]   8   16   23   13+  22   28  447  318   30   12   24  245    7    9  511 
## [16]  30   53  196   15  154    7  333  141    8+  96   38  149+  70+ 536   25+
## [31]  17    4+ 185  177  292  114   22+ 159+  15  108+ 152  562  402   24+  13 
## [46]  66   39   46+  12   40  113+ 201  132  156   34   30    2   25  130   26 
## [61]  27   58    5+  43  152   30  190    5+ 119    8   54+  16+   6+  78   63 
## [76]   8+

例21

## Survival curves from survfit(), one curve per value of disease
fit <- survfit(Surv(time, status) ~ disease, data = kidney)
plot(fit, xlab = "時間", ylab = "生存率", col = c("red", "green", "blue", "black"))
## NOTE(review): the legend labels assume the curves come in the order
## Other, GN, AN, PKD - confirm against levels(kidney$disease).
legend(300, 0.8, legend = c("その他", "GN", "AN", "PKD"), 
       lty = 1, col = c("red", "green", "blue", "black"))

## Lasso for the Cox proportional hazards model, fitted by repeatedly solving
## a weighted linear lasso on a working response.
##   X:      n x p covariate matrix (no intercept column)
##   y:      observed times; only length(y) is used here, so the rows of X,
##           y and delta are expected to be sorted by increasing time already
##           (the kidney example in this file sorts before calling)
##   delta:  status indicator, 1 = event observed; the first entry is forced
##           to 1 inside the function
##   lambda: L1 penalty weight. The default is now 0; the previous default
##           `lambda = lambda` referred to itself and raised an error
##           whenever the argument was omitted.
## Prints the current coefficients at every iteration and returns the final
## p coefficients (the intercept returned by W.linear.lasso() is dropped).
cox.lasso <- function(X, y, delta, lambda = 0) {
  X <- as.matrix(X)
  p <- ncol(X)   ## taken from X rather than from a global variable p
  delta[1] <- 1
  n <- length(y)
  w <- array(dim = n)
  u <- array(dim = n)
  beta <- rnorm(p)
  gamma <- rep(0, p)
  while (sum((beta - gamma) ^ 2) > 10 ^ {-4}) {
    beta <- gamma
    s <- as.vector(X %*% beta)
    v <- exp(s)
    ## risk[j] = sum of v over rows j..n, computed once per j instead of once
    ## per (i, j) pair; ratio[i, j] = v[i] / risk[j]
    risk <- vapply(seq_len(n), function(j) sum(v[j:n]), numeric(1))
    ratio <- outer(v, risk, "/")
    for (i in 1:n) {
      u[i] <- delta[i]
      w[i] <- 0
      ## accumulate over the events at positions 1..i
      for (j in 1:i) {
        if (delta[j] == 1) {
          u[i] <- u[i] - ratio[i, j]
          w[i] <- w[i] + ratio[i, j] * (1 - ratio[i, j])
        }
      }
    }
    z <- s + u / w   ## working response
    W <- diag(w)
    print(gamma)
    gamma <- W.linear.lasso(X, z, W, lambda = lambda)[-1]
  }
  return(gamma)
}

例22

## Kidney data sorted by increasing time
df <- kidney
index <- order(df$time)
df <- df[index, ]
n <- nrow(df)
p <- 4                           ## number of covariate columns assembled below
y <- as.numeric(df[[2]])         ## column 2: time
delta <- as.numeric(df[[3]])     ## column 3: status
X <- as.numeric(df[[4]])         ## column 4: age
## append columns 5-7 (sex, disease, frail), each converted with as.numeric
for (j in 5:7)
  X <- cbind(X, as.numeric(df[[j]]))
z <- Surv(y, delta)
## Fits for lambda = 0, 0.1 and 0.2; each printed line is one iteration
cox.lasso(X, y, delta, 0)
## [1] 0 0 0 0
## [1]  0.0101287 -1.7747758 -0.3887608  1.3532378
## [1]  0.01462571 -1.69299527 -0.41598742  1.38980788
## [1]  0.01591941 -1.66769665 -0.42331475  1.40330234
## [1]  0.01628935 -1.66060178 -0.42528537  1.40862969
cox.lasso(X, y, delta, 0.1)
## [1] 0 0 0 0
## [1]  0.00000000 -1.04501150 -0.08845692  1.00600836
## [1]  0.00000000 -0.97761575 -0.05957538  0.97284124
## [1]  0.0000000 -0.9567328 -0.0530083  0.9592735
## [1]  0.00000000 -0.95074758 -0.05159338  0.95420530
cox.lasso(X, y, delta, 0.2)
## [1] 0 0 0 0
## [1]  0.0000000 -0.5312546  0.0000000  0.7201078
## [1]  0.0000000 -0.5090856  0.0000000  0.6857966
## [1]  0.0000000 -0.5048733  0.0000000  0.6768827
## Same model fitted by glmnet at lambda = 0.1, for comparison
glmnet(X, z, family = "cox", lambda = 0.1)$beta
## 4 x 1 sparse Matrix of class "dgCMatrix"
##            s0
## X  .         
##   -0.87359015
##   -0.05659599
##    0.92923820

例23

library(survival)
## Load the lymphoma data. load() already places `patient.data` in the
## workspace, so the file is not attached to the search path as well (the
## attached copy was only ever masked by the workspace one).
load("LymphomaData.rda")
names(patient.data)
## [1] "x"      "time"   "status"
x <- t(patient.data$x)          ## transposed covariate matrix
y <- patient.data$time          ## observed times
delta <- patient.data$status    ## status indicator
Surv(y, delta)
##   [1]  5.0+  5.9+  6.6+ 13.1+  1.6   1.3   1.4   2.2   3.4   2.2   9.8+  9.1+
##  [13]  6.2+  9.1+  3.8   5.5   5.3+  3.5   2.3   1.1   2.7   8.2+  1.6   1.0 
##  [25]  1.3   1.3   3.5  11.8+  8.7+  6.9+ 17.9  10.7+ 11.1+  2.6+  1.8   4.9 
##  [37]  4.3   8.1+  4.3+  7.8+  1.0   1.1   3.7  22.8+  1.8   7.2   8.2+  1.4 
##  [49]  8.2+  7.0+  2.0   2.7+  1.6   3.3+  6.4   5.1  15.4+ 15.2+  2.1   1.7 
##  [61]  5.0+  2.0   2.7+  3.0+  1.4   2.9+ 10.7+  1.1   8.4+  6.0+ 11.0+  1.3 
##  [73]  2.3   6.5+  4.4+  7.9   2.9   1.3  11.5+  1.4   1.3   3.6   1.3   2.3 
##  [85]  2.5   1.5   3.4   1.9   3.8   1.1   3.3+  1.0   1.2  13.2+  4.3+  7.8 
##  [97]  3.5+  2.4   2.0   8.6+  2.6  10.2+  5.2+  2.1   6.9+  1.7   1.7   4.9+
## [109]  1.3   3.8+ 14.3+  9.4+  2.0  11.3+  6.6+ 14.8+ 17.8+  1.9   6.6   1.8+
## [121]  3.7+  3.8+  2.0+  5.8+  7.7+  1.2  10.1+  1.7   1.3   1.3   7.7+  6.8+
## [133]  8.1+  3.3   7.5+ 10.7+  1.0  15.6+  3.9   7.6   3.3  12.6+  1.2   1.7 
## [145]  3.1   3.3   3.3   2.3   1.1   1.5   2.0  13.3+  8.8+  3.1   1.4   3.0 
## [157]  7.5+  2.0   8.3+  7.0  12.0  11.6   2.0   6.4   1.8  11.5+ 10.6+ 10.0+
## [169]  8.4+  8.5  12.3+  3.0   1.4  12.4+  4.0   9.4+  1.2   6.0+  1.9   1.6 
## [181]  1.6   7.4+  5.8+  6.0   1.3  10.5+  5.1+  9.9+  2.5   2.1   4.9+  1.1 
## [193]  2.3  16.4+ 10.1   5.3   1.1   4.6   1.4  11.2+ 11.4+  1.7   1.0   1.7 
## [205]  1.1   5.1+  4.5+  1.5  12.7+  7.1   2.1   2.9  10.5+  1.4   2.3   2.1 
## [217]  1.7   2.6   4.4+ 10.2+  2.2  12.2+  2.4   9.8+  4.0   1.4   3.9   1.4 
## [229] 18.4+ 17.8+  2.3   5.0+  6.6  20.8+  7.9   1.3   4.1   8.7+  3.7   2.1
library(ranger)
## Warning: package 'ranger' was built under R version 4.0.3
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.0.3
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.0.3
## 
## Attaching package: 'dplyr'
## The following object is masked from 'package:MASS':
## 
##     select
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggfortify)
## Warning: package 'ggfortify' was built under R version 4.0.3
## Cox lasso: choose lambda by cross-validation, then refit at lambda.min
cv.fit <- cv.glmnet(x, Surv(y, delta), family = "cox")
fit2 <- glmnet(x, Surv(y, delta), lambda = cv.fit$lambda.min, family = "cox")
## Split the samples by the sign of the fitted linear predictor
z <- sign(drop(x %*% fit2$beta))
## Survival curves for the resulting groups
fit3 <- survfit(Surv(y, delta) ~ z)
autoplot(fit3)

## Mean observed time within each group
mean(y[z == 1])
## [1] 3.146429
mean(y[z == -1])
## [1] 7.392188