soft.th <- function(lambda, x) {
return(sign(x) * pmax(abs(x) - lambda, 0))
}
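## A minimal check, not part of the original code: the soft-thresholding
## operator shrinks each component toward zero by lambda and maps anything
## with absolute value at most lambda to exactly zero.
soft.th(1, c(-3, -0.5, 0, 0.5, 3)) ## should return -2 0 0 0 2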
linear.lasso <- function(X, y, lambda = 0, beta = rep(0, ncol(X))) {
n <- nrow(X)
p <- ncol(X)
res <- centralize(X, y) ## centering and standardization (see the function below)
X <- res$X
y <- res$y
eps <- 1
beta.old <- beta
while (eps > 0.001) { ## iterate this loop until convergence
for (j in 1:p) {
r <- y - as.matrix(X[, -j]) %*% beta[-j]
beta[j] <- soft.th(lambda, sum(r * X[, j]) / n) / (sum(X[, j] * X[, j]) / n)
}
eps <- max(abs(beta - beta.old))
beta.old <- beta
}
beta <- beta / res$X.sd ## rescale each coefficient back to the scale before standardization
beta.0 <- res$y.bar - sum(res$X.bar * beta) ## intercept
return(list(beta = beta, beta.0 = beta.0))
}
centralize <- function(X, y, standardize = TRUE) {
X <- as.matrix(X)
n <- nrow(X)
p <- ncol(X)
X.bar <- array(dim = p) ## mean of each column of X
X.sd <- array(dim = p) ## standard deviation of each column of X
for (j in 1:p) {
X.bar[j] <- mean(X[, j])
X[, j] <- (X[, j] - X.bar[j]) ## center each column of X
X.sd[j] <- sqrt(var(X[, j]))
if (standardize)
X[, j] <- X[, j] / X.sd[j] ## standardize each column of X
}
if (is.matrix(y)) { ## when y is a matrix
K <- ncol(y)
y.bar <- array(dim = K) ## mean of each column of y
for (k in 1:K) {
y.bar[k] <- mean(y[, k])
y[, k] <- y[, k] - y.bar[k] ## center y
}
} else { ## when y is a vector
y.bar <- mean(y)
y <- y - y.bar
}
return(list(X = X, y = y, X.bar = X.bar, X.sd = X.sd, y.bar = y.bar))
}
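## A hedged usage sketch (not in the original): on small simulated data,
## linear.lasso with lambda = 0 should roughly agree with ordinary least
## squares; the names n.ex, p.ex, X.ex, y.ex are illustrative.
set.seed(1)
n.ex <- 50
p.ex <- 3
X.ex <- matrix(rnorm(n.ex * p.ex), ncol = p.ex)
y.ex <- as.vector(X.ex %*% c(1, -2, 0)) + rnorm(n.ex)
linear.lasso(X.ex, y.ex, lambda = 0) ## coordinate-descent solution
lm(y.ex ~ X.ex)$coefficients ## reference fit for comparison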
W.linear.lasso <- function(X, y, W, lambda = 0) {
n <- nrow(X)
p <- ncol(X)
X.bar <- array(dim = p)
for (k in 1:p) {
X.bar[k] <- sum(W %*% X[, k]) / sum(W) ## weighted mean of each column of X
X[, k] <- X[, k] - X.bar[k]
}
y.bar <- sum(W %*% y) / sum(W) ## weighted mean of y
y <- y - y.bar
L <- chol(W) ## Cholesky factor of the weight matrix
# L <- sqrt(W) ## equivalent when W is diagonal
u <- as.vector(L %*% y) ## transform to an ordinary (unweighted) lasso problem
V <- L %*% X
beta <- linear.lasso(V, u, lambda)$beta
beta.0 <- y.bar - sum(X.bar * beta)
return(c(beta.0, beta))
}
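## A hedged sanity check (not in the original): with identity weights W = I,
## the weighted problem reduces to the ordinary lasso, so the two functions
## should give essentially the same coefficients (the *.ex names are illustrative).
set.seed(2)
X.ex <- matrix(rnorm(40), ncol = 2)
y.ex <- as.vector(X.ex %*% c(2, -1)) + rnorm(20)
W.linear.lasso(X.ex, y.ex, diag(20), lambda = 0.1) ## returns c(beta.0, beta)
linear.lasso(X.ex, y.ex, lambda = 0.1) ## same beta and beta.0, as a list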
f <- function(x) {
return(exp(beta.0 + beta * x) / (1 + exp(beta.0 + beta * x)))
}
beta.0 <- 0
beta.seq <- c(0, 0.2, 0.5, 1, 2, 10)
m <- length(beta.seq)
beta <- beta.seq[1]
plot(f, xlim = c(-10, 10), ylim = c(0, 1), xlab = "x", ylab = "y",
col = 1, main = "Logistic curves")
for (i in 2:m) {
beta <- beta.seq[i]
par(new = TRUE)
plot(f, xlim = c(-10, 10), ylim = c(0, 1), xlab = "", ylab = "", axes = FALSE, col = i)
}
legend("topleft", legend = beta.seq, col = 1:length(beta.seq), lwd = 2, cex = .8)
par(new = FALSE)
## Data generation
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- rnorm(p + 1)
y <- array(dim = N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s))
for (i in 1:N) {
if (runif(1) > prob[i]) {
y[i] <- 1
} else {
y[i] <- -1
}
}
beta
## [1] -0.8251151 -0.6041347 2.8357392
## Computing the maximum likelihood estimate
beta <- Inf
gamma <- rnorm(p + 1)
while (sum((beta - gamma) ^ 2) > 0.001) {
beta <- gamma
s <- as.vector(X %*% beta)
v <- exp(-s * y)
u <- y * v / (1 + v) ## score (first derivative) for each observation
w <- v / (1 + v) ^ 2 ## IRLS weights (negative second derivative)
z <- s + u / w ## working response
W <- diag(w)
gamma <- as.vector(solve(t(X) %*% W %*% X) %*% t(X) %*% W %*% z) ## weighted least-squares update
print(gamma)
}
## [1] -1.09048865 -0.01923956 2.11691931
## [1] -0.7567999 -0.5396643 2.5134395
## [1] -0.7908441 -0.6295797 2.8209931
## [1] -0.7982177 -0.6412975 2.8641928
## [1] -0.7983656 -0.6414807 2.8649043
beta ## estimate at convergence; compare with the true beta printed above
## [1] -0.7982177 -0.6412975 2.8641928
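## A hedged cross-check (not in the original): since P(y = 1 | x) is the
## logistic function of x' beta, recoding y to 0/1 and calling glm with
## family = binomial should give roughly the same estimate as the
## Newton-Raphson iterations above.
coef(glm((y + 1) / 2 ~ X[, -1], family = binomial))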
logistic.lasso <- function(X, y, lambda) {
p <- ncol(X)
beta <- Inf
gamma <- rnorm(p)
while (sum((beta - gamma) ^ 2) > 0.01) {
beta <- gamma
s <- as.vector(X %*% beta)
v <- as.vector(exp(-s * y))
u <- y * v / (1 + v)
w <- v / (1 + v) ^ 2
z <- s + u / w
W <- diag(w)
gamma <- W.linear.lasso(X[, 2:p], z, W, lambda = lambda)
print(gamma)
}
return(gamma)
}
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- rnorm(p + 1)
y <- array(dim = N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s))
for (i in 1:N) {
if (runif(1) > prob[i]) {
y[i] <- 1
} else {
y[i] <- -1
}
}
#logistic.lasso(X, y, 0)
logistic.lasso(X, y, 0.1)
## [1] -0.1179388 1.7995542 -0.7891112
## [1] -0.1239010 1.3534193 0.2157502
## [1] -0.13270237 1.52965899 0.01516135
## [1] -0.136756435 1.584675126 0.009576722
## [1] -0.136756435 1.584675126 0.009576722
logistic.lasso(X, y, 0.2)
## [1] 0.1158749 1.3300292 0.0000000
## [1] -0.1216319 1.2138788 0.0000000
## [1] -0.1118214 1.2051842 0.0000000
## [1] -0.1118214 1.2051842 0.0000000
## Data generation
N <- 1000
p <- 2
X <- matrix(rnorm(N * p), ncol = p)
X <- cbind(rep(1, N), X)
beta <- 10 * rnorm(p + 1)
y <- array(dim = N)
s <- as.vector(X %*% beta)
prob <- 1 / (1 + exp(s))
for (i in 1:N) {
if (runif(1) > prob[i]) {
y[i] <- 1
} else {
y[i] <- -1
}
}
## Parameter estimation
beta.est <- logistic.lasso(X, y, 0.1)
## [1] 1.001715 -1.602514 0.000000
## [1] 1.3120540 -2.3016201 -0.2229933
## [1] 1.5723713 -2.8611635 -0.3413522
## [1] 1.7273824 -3.1952329 -0.4064963
## [1] 1.7918640 -3.3358622 -0.4353427
## [1] 1.8119803 -3.3802578 -0.4451937
## Classification: draw fresh labels from the same probabilities, then classify by the sign of the linear score
for (i in 1:N) {
if (runif(1) > prob[i]) {
y[i] <- 1
} else {
y[i] <- -1
}
}
z <- sign(X %*% beta.est) ## classify as +1 if the linear score (exponent) is positive, -1 if negative
table(y, z)
## z
## y -1 1
## -1 266 62
## 1 37 635
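## A small addition (not in the original): overall classification accuracy
## from the confusion matrix above.
mean(y == as.vector(z))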
library(glmnet)
## Warning: package 'glmnet' was built under R version 4.0.3
## Loading required package: Matrix
## Loaded glmnet 4.0-2
df <- read.csv("breastcancer.csv")
## place the file breastcancer.csv in the current working directory
x <- as.matrix(df[, 1:1000])
y <- as.vector(df[, 1001])
cv <- cv.glmnet(x, y, family = "binomial")
cv2 <- cv.glmnet(x, y, family = "binomial", type.measure = "class")
par(mfrow = c(1, 2))
plot(cv)
plot(cv2)
par(mfrow = c(1, 1))
glm <- glmnet(x, y, lambda = 0.03, family = "binomial")
beta <- drop(glm$beta)
beta[beta != 0]
## A.200053_at A.200740_s_at A.200855_at A.202011_at A.202117_at
## 0.253816593 -0.009995197 0.127286134 0.050447833 0.027646245
## A.203188_at A.203287_at A.203347_s_at A.206197_at A.208184_s_at
## 0.022838365 -0.044863822 -0.258511276 0.044533711 -0.161215317
## A.209608_s_at A.210137_s_at A.212708_at A.213285_at A.216515_x_at
## -0.191407462 0.209509915 0.033927033 0.016215819 -0.014640332
## A.218080_x_at A.218795_at A.218877_s_at A.219252_s_at A.219490_s_at
## -0.140710024 0.221924131 -0.135956993 0.374797685 -0.022592177
## A.221562_s_at A.221740_x_at A.221951_at B.224217_s_at B.226831_at
## 0.148270964 0.180185565 0.201092912 -0.036098456 0.143513787
## B.227423_at B.228081_at B.229181_s_at B.229342_at B.232398_at
## -0.030834163 0.023929217 -0.201146289 0.015535533 -0.169843879
## B.233413_at B.238425_at B.242255_at
## 0.150160625 -0.284483809 -0.285904123
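## A hedged addition (not in the original): the number of selected genes and
## in-sample class predictions via predict with type = "class".
sum(beta != 0) ## number of nonzero coefficients
pred <- predict(glm, x, type = "class")
table(y, as.vector(pred))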
multi.lasso <- function(X, y, lambda) {
X <- as.matrix(X)
p <- ncol(X)
n <- nrow(X)
K <- length(table(y))
beta <- matrix(1, nrow = K, ncol = p)
gamma <- matrix(0, nrow = K, ncol = p)
while (norm(beta - gamma, "F") > 0.1) {
gamma <- beta
for (k in 1:K) {
r <- 0
for (h in 1:K)
if (k != h)
r <- r + exp(as.vector(X %*% beta[h, ]))
v <- exp(as.vector(X %*% beta[k, ])) / r
u <- as.numeric(y == k) - v / (1 + v)
w <- v / (1 + v) ^ 2
z <- as.vector(X %*% beta[k, ]) + u / w
beta[k, ] <- W.linear.lasso(X[, 2:p], z, diag(w), lambda = lambda)
print(beta[k, ])
}
for (j in 1:p) { ## subtract the per-coordinate median across classes to remove the redundancy in the multinomial parameterization
med <- median(beta[, j])
for (h in 1:K)
beta[h, j] <- beta[h, j] - med
}
}
return(beta)
}
df <- iris
x <- matrix(0, 150, 4)
for (j in 1:4)
x[, j] <- df[[j]]
X <- cbind(1, x)
y <- c(rep(1, 50), rep(2, 50), rep(3, 50))
beta <- multi.lasso(X, y, 0.01)
## [1] 0.3080950 1.2921571 2.0231190 0.0000000 0.6787529
## [1] 7.2969918 1.1262008 -0.6350756 1.0893138 -0.6835181
## [1] -0.4022633 0.7318326 0.8424524 1.3606179 2.2923359
## [1] 2.4562742 0.0000000 1.3406646 -1.6910411 -0.1130906
## [1] 6.231622 0.000000 -1.150196 0.000000 -1.324055
## [1] -1.8822458 -0.8146104 -1.5921661 1.5236968 2.9584280
## [1] 1.4538467 0.0000000 2.6462788 -2.2065299 -0.5212037
## [1] 4.03562210 0.00000000 -0.04287595 0.00000000 -1.25669808
## [1] -7.584966 -1.151576 -1.639842 2.804778 4.575195
## [1] 1.1910878 0.0000000 2.8776746 -2.6542450 -0.6417311
## [1] 2.66120636 0.00000000 -0.01066202 0.00000000 -0.74616319
## [1] -13.763670 -1.464214 -2.468618 4.081814 6.832851
## [1] 1.1048739 0.0000000 3.0692005 -3.0580639 -0.7177671
## [1] 1.47419608 0.00000000 0.00000000 0.00000000 -0.08951308
## [1] -20.565411 -1.732446 -3.210548 5.308382 9.500559
## [1] 1.221106 0.000000 3.184656 -3.441061 -1.380647
## [1] 0.3896127 0.0000000 0.0000000 0.0000000 0.0000000
## [1] -27.363672 -1.904809 -3.977700 6.391882 11.737779
## [1] 2.333029 0.000000 3.190658 -3.847962 -2.057788
## [1] 0.01083126 0.00000000 0.00000000 0.00000000 0.00000000
## [1] -32.274017 -1.995927 -4.635615 7.197584 13.509639
## [1] 4.156703 0.000000 3.068626 -4.299598 -2.588838
## [1] 0.002601294 0.000000000 0.000000000 0.000000000 0.000000000
## [1] -34.511033 -2.035226 -4.965812 7.589258 14.386216
## [1] 6.325169 0.000000 2.822627 -4.825354 -2.881859
## [1] 0.0002078256 0.0000000000 0.0000000000 0.0000000000 0.0000000000
## [1] -35.052617 -2.048477 -5.052837 7.690643 14.604794
## [1] 8.783234 0.000000 2.472210 -5.420466 -2.935649
## [1] 6.881184e-06 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.112978 -2.052167 -5.061101 7.704190 14.628298
## [1] 11.602607 0.000000 2.036862 -6.175135 -2.594422
## [1] -6.608283e-05 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119306 -2.052661 -5.062040 7.705757 14.630812
## [1] 14.431959 0.000000 1.633922 -7.082344 -1.851916
## [1] -7.222944e-05 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.120009 -2.052719 -5.062148 7.705937 14.631095
## [1] 16.467117 0.000000 1.391694 -7.833493 -1.162274
## [1] -3.886799e-05 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.120030 -2.052723 -5.062156 7.705949 14.631118
## [1] 17.4916404 0.0000000 1.3275401 -8.3432360 -0.6005499
## [1] -1.202671e-05 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119990 -2.052720 -5.062154 7.705944 14.631112
## [1] 17.7362987 0.0000000 1.3730213 -8.6046329 -0.2414809
## [1] -3.149099e-06 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119972 -2.052719 -5.062153 7.705941 14.631109
## [1] 17.65919548 0.00000000 1.43419806 -8.69220015 -0.08766604
## [1] -9.567042e-07 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119967 -2.052719 -5.062152 7.705940 14.631108
## [1] 17.55654211 0.00000000 1.47185297 -8.70854021 -0.04348658
## [1] -2.486972e-07 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119966 -2.052719 -5.062152 7.705940 14.631108
## [1] 17.49822846 0.00000000 1.48944145 -8.70864338 -0.03400261
## [1] -5.091597e-08 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00
## [1] -35.119965 -2.052719 -5.062152 7.705940 14.631108
X %*% t(beta)
## [,1] [,2] [,3]
## [1,] 10.512372 0 -49.5918259
## [2,] 9.767652 0 -46.6502061
## [3,] 10.936404 0 -48.0226868
## [4,] 9.045731 0 -45.7700118
## [5,] 10.661316 0 -49.8927693
## [6,] 8.488755 0 -46.9944990
## [7,] 10.360028 0 -46.5961406
## [8,] 9.492564 0 -48.1097449
## [9,] 9.618707 0 -45.1176316
## [10,] 9.049132 0 -47.8489381
## [11,] 9.939396 0 -50.4494780
## [12,] 8.621700 0 -46.9286072
## [13,] 9.771052 0 -47.9080450
## [14,] 12.383645 0 -49.1934676
## [15,] 12.998822 0 -55.1009930
## [16,] 10.975205 0 -51.6825786
## [17,] 11.972213 0 -50.0768748
## [18,] 10.508972 0 -48.1287152
## [19,] 8.343211 0 -48.5672101
## [20,] 10.084940 0 -48.8767668
## [21,] 7.750835 0 -47.3896444
## [22,] 9.932596 0 -46.9074409
## [23,] 14.144774 0 -52.1540576
## [24,] 7.591690 0 -41.8782814
## [25,] 6.009107 0 -44.6168253
## [26,] 8.025923 0 -45.3142901
## [27,] 8.614899 0 -44.4129294
## [28,] 9.641508 0 -49.0265038
## [29,] 10.363428 0 -49.2908826
## [30,] 8.323811 0 -45.7109049
## [31,] 8.174867 0 -45.4099616
## [32,] 9.485763 0 -46.0046109
## [33,] 10.538573 0 -53.5269058
## [34,] 11.554981 0 -53.9564199
## [35,] 9.045731 0 -46.3858274
## [36,] 11.807269 0 -49.4090964
## [37,] 11.383237 0 -51.1835074
## [38,] 10.664717 0 -51.1506082
## [39,] 10.638516 0 -46.3944408
## [40,] 9.492564 0 -48.3150168
## [41,] 11.379836 0 -48.6940373
## [42,] 9.592507 0 -41.5930954
## [43,] 10.936404 0 -47.4068712
## [44,] 8.757043 0 -41.9929231
## [45,] 6.598083 0 -44.3312803
## [46,] 9.764251 0 -44.9818235
## [47,] 9.217476 0 -49.5692836
## [48,] 10.065540 0 -47.0468210
## [49,] 9.939396 0 -50.2442061
## [50,] 10.214484 0 -48.3741236
## [51,] -18.713786 0 -8.9864161
## [52,] -16.975458 0 -7.8328621
## [53,] -20.607859 0 -5.2706303
## [54,] -13.954833 0 -8.2086697
## [55,] -18.442099 0 -5.2426791
## [56,] -17.564434 0 -7.2973196
## [57,] -18.571643 0 -5.1295067
## [58,] -7.699638 0 -17.2667436
## [59,] -18.286354 0 -8.8803877
## [60,] -12.491592 0 -8.9251981
## [61,] -10.037143 0 -13.9059667
## [62,] -14.660753 0 -8.1058542
## [63,] -14.093576 0 -13.1181460
## [64,] -19.160619 0 -5.6203237
## [65,] -9.577711 0 -14.5336086
## [66,] -16.250138 0 -10.1761672
## [67,] -17.273346 0 -5.1782567
## [68,] -14.219720 0 -14.4680844
## [69,] -18.464899 0 -2.3601662
## [70,] -12.779280 0 -13.1231874
## [71,] -19.598252 0 -0.1053886
## [72,] -13.210112 0 -11.9713769
## [73,] -21.501524 0 -1.0017079
## [74,] -19.302762 0 -8.0403299
## [75,] -15.673761 0 -10.7816258
## [76,] -16.399082 0 -9.4646801
## [77,] -20.180427 0 -5.7804176
## [78,] -21.634468 0 -0.6570559
## [79,] -17.422290 0 -5.4931290
## [80,] -9.143478 0 -18.3801611
## [81,] -12.057360 0 -13.1822943
## [82,] -11.183095 0 -15.4159990
## [83,] -12.484792 0 -13.0830508
## [84,] -22.948765 0 1.6059760
## [85,] -17.273346 0 -4.7677130
## [86,] -16.680970 0 -6.5610943
## [87,] -18.866131 0 -6.4012745
## [88,] -17.438290 0 -6.7684688
## [89,] -13.783088 0 -11.1868540
## [90,] -13.656945 0 -9.2211001
## [91,] -16.988058 0 -8.1080502
## [92,] -18.140810 0 -6.8971328
## [93,] -13.504600 0 -11.8062416
## [94,] -7.848582 0 -16.9658003
## [95,] -15.100785 0 -8.8976144
## [96,] -14.650552 0 -12.0846427
## [97,] -14.802897 0 -10.1153167
## [98,] -15.673761 0 -10.3710821
## [99,] -4.941501 0 -19.0321737
## [100,] -14.080977 0 -10.3796955
## [101,] -29.923482 0 18.0562115
## [102,] -22.958966 0 6.4058520
## [103,] -29.485849 0 11.3096452
## [104,] -27.011999 0 6.7569213
## [105,] -28.618384 0 13.2337932
## [106,] -35.581899 0 15.6774436
## [107,] -18.024868 0 1.7159439
## [108,] -33.108049 0 10.0983603
## [109,] -29.349504 0 9.5018825
## [110,] -30.347513 0 15.4607130
## [111,] -22.217645 0 3.9009836
## [112,] -24.700694 0 6.7154087
## [113,] -26.002391 0 8.8430850
## [114,] -22.389390 0 8.3160710
## [115,] -22.827023 0 13.2151905
## [116,] -23.969575 0 10.0367756
## [117,] -25.992190 0 5.0695684
## [118,] -35.264610 0 13.6561548
## [119,] -38.797069 0 22.7350360
## [120,] -22.819221 0 1.9033473
## [121,] -27.453032 0 12.0927921
## [122,] -21.071693 0 6.2321033
## [123,] -36.747251 0 15.7920854
## [124,] -21.213837 0 2.3751940
## [125,] -27.297287 0 9.0708992
## [126,] -30.048624 0 6.4732047
## [127,] -20.194028 0 1.3036567
## [128,] -20.767004 0 1.2670921
## [129,] -27.171144 0 11.4471969
## [130,] -28.597983 0 3.0182257
## [131,] -31.518665 0 10.3212265
## [132,] -32.645217 0 8.0076077
## [133,] -27.174544 0 12.9103076
## [134,] -22.796421 0 -0.9791656
## [135,] -27.445230 0 2.8336676
## [136,] -31.234378 0 14.5454235
## [137,] -26.287680 0 13.0045097
## [138,] -25.843246 0 4.7686250
## [139,] -19.896140 0 0.7017700
## [140,] -24.982583 0 7.3610040
## [141,] -26.734512 0 13.7020679
## [142,] -22.376790 0 7.9754436
## [143,] -22.958966 0 6.4058520
## [144,] -29.194761 0 13.8392519
## [145,] -27.310888 0 14.9233422
## [146,] -23.396599 0 9.6627965
## [147,] -22.385990 0 5.6213291
## [148,] -23.386398 0 5.6840080
## [149,] -24.542551 0 10.2054829
## [150,] -22.508733 0 3.2188237
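## A hedged addition (not in the original): assign each observation to the
## class with the largest linear score and compare with the true labels.
y.hat <- apply(X %*% t(beta), 1, which.max)
table(y, y.hat)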
library(glmnet)
df <- iris
x <- as.matrix(df[, 1:4])
y <- as.vector(df[, 5])
n <- length(y)
u <- array(dim = n)
for (i in 1:n) {
if (y[i] == "setosa") {
u[i] <- 1
} else if (y[i] == "versicolor") {
u[i] <- 2
} else {
u[i] <- 3
}
}
u <- as.numeric(u)
cv <- cv.glmnet(x, u, family = "multinomial")
cv2 <- cv.glmnet(x, u, family = "multinomial", type.measure = "class")
par(mfrow = c(1, 2))
plot(cv)
plot(cv2)
par(mfrow = c(1, 1))
lambda <- cv$lambda.min
result <- glmnet(x, y, lambda = lambda, family = "multinomial")
beta <- result$beta
beta.0 <- result$a0
v <- rep(0, n)
for (i in 1:n) {
max.value <- -Inf
for (j in 1:3) {
value <- beta.0[j] + sum(beta[[j]] * x[i, ])
if (value > max.value) {
v[i] <- j
max.value <- value
}
}
}
table(u, v)
## v
## u 1 2 3
## 1 50 0 0
## 2 0 48 2
## 3 0 1 49
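## A hedged alternative (not in the original) to the explicit loop above:
## predict with type = "class" returns the class assignments directly.
pred <- predict(result, x, type = "class")
table(y, as.vector(pred))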
poisson.lasso <- function(X, y, lambda) { ## assumes the global p and that the first column of X is the intercept
beta <- rnorm(p + 1)
gamma <- rnorm(p + 1)
while (sum((beta - gamma) ^ 2) > 0.0001) {
beta <- gamma
s <- as.vector(X %*% beta)
w <- exp(s)
u <- y - w
z <- s + u / w
W <- diag(w)
gamma <- W.linear.lasso(X[, 2:(p + 1)], z, W, lambda)
print(gamma)
}
return(gamma)
}
n <- 1000
p <- 3
beta <- rnorm(p + 1)
X <- matrix(rnorm(n * p), ncol = p)
X <- cbind(1, X)
s <- as.vector(X %*% beta)
y <- rpois(n, lambda = exp(s))
beta
## [1] -1.0674998 -0.1966005 -1.1714670 0.9481667
poisson.lasso(X, y, 0.2)
## [1] 0.5148544 0.2274658 0.0000000 0.2304827
## [1] -0.03376344 0.10647382 0.00000000 0.03611964
## [1] -0.44375007 0.07246886 -0.20635897 0.00000000
## [1] -0.6729065 0.1041351 -0.3866383 0.0000000
## [1] -0.7514708 0.1366967 -0.4397856 0.0000000
## [1] -0.7625520 0.1455526 -0.4472076 0.0000000
## [1] -0.7629401 0.1460133 -0.4474231 0.0000000
## [1] -0.7629401 0.1460133 -0.4474231 0.0000000
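## A hedged sanity check (not in the original): with lambda = 0 the update is
## iteratively reweighted least squares for Poisson regression, so the result
## should roughly match glm with family = poisson.
poisson.lasso(X, y, 0)
coef(glm(y ~ X[, -1], family = poisson))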
library(glmnet)
library(MASS)
data(birthwt)
df <- birthwt[, -1]
dy <- df[, 8] ## response: ftv, the number of physician visits
dx <- data.matrix(df[, -8]) ## covariates
cvfit <- cv.glmnet(x = dx, y = dy, family = "poisson", standardize = TRUE)
coef(cvfit, s = "lambda.min")
## 9 x 1 sparse Matrix of class "dgCMatrix"
## 1
## (Intercept) -0.8603852995
## age 0.0243126476
## lwt 0.0004265947
## race .
## smoke .
## ptl .
## ht .
## ui .
## bwt .
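## A hedged addition (not in the original): fitted mean counts on the training
## data can be obtained with predict and type = "response".
mu.hat <- predict(cvfit, newx = dx, s = "lambda.min", type = "response")
head(cbind(observed = dy, fitted = drop(mu.hat)))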
library(survival)
data(kidney)
kidney
## id time status age sex disease frail
## 1 1 8 1 28 1 Other 2.3
## 2 1 16 1 28 1 Other 2.3
## 3 2 23 1 48 2 GN 1.9
## 4 2 13 0 48 2 GN 1.9
## 5 3 22 1 32 1 Other 1.2
## 6 3 28 1 32 1 Other 1.2
## 7 4 447 1 31 2 Other 0.5
## 8 4 318 1 32 2 Other 0.5
## 9 5 30 1 10 1 Other 1.5
## 10 5 12 1 10 1 Other 1.5
## 11 6 24 1 16 2 Other 1.1
## 12 6 245 1 17 2 Other 1.1
## 13 7 7 1 51 1 GN 3.0
## 14 7 9 1 51 1 GN 3.0
## 15 8 511 1 55 2 GN 0.5
## 16 8 30 1 56 2 GN 0.5
## 17 9 53 1 69 2 AN 0.7
## 18 9 196 1 69 2 AN 0.7
## 19 10 15 1 51 1 GN 0.4
## 20 10 154 1 52 1 GN 0.4
## 21 11 7 1 44 2 AN 0.6
## 22 11 333 1 44 2 AN 0.6
## 23 12 141 1 34 2 Other 1.2
## 24 12 8 0 34 2 Other 1.2
## 25 13 96 1 35 2 AN 1.4
## 26 13 38 1 35 2 AN 1.4
## 27 14 149 0 42 2 AN 0.4
## 28 14 70 0 42 2 AN 0.4
## 29 15 536 1 17 2 Other 0.4
## 30 15 25 0 17 2 Other 0.4
## 31 16 17 1 60 1 AN 1.1
## 32 16 4 0 60 1 AN 1.1
## 33 17 185 1 60 2 Other 0.8
## 34 17 177 1 60 2 Other 0.8
## 35 18 292 1 43 2 Other 0.8
## 36 18 114 1 44 2 Other 0.8
## 37 19 22 0 53 2 GN 0.5
## 38 19 159 0 53 2 GN 0.5
## 39 20 15 1 44 2 Other 1.3
## 40 20 108 0 44 2 Other 1.3
## 41 21 152 1 46 1 PKD 0.2
## 42 21 562 1 47 1 PKD 0.2
## 43 22 402 1 30 2 Other 0.6
## 44 22 24 0 30 2 Other 0.6
## 45 23 13 1 62 2 AN 1.7
## 46 23 66 1 63 2 AN 1.7
## 47 24 39 1 42 2 AN 1.0
## 48 24 46 0 43 2 AN 1.0
## 49 25 12 1 43 1 AN 0.7
## 50 25 40 1 43 1 AN 0.7
## 51 26 113 0 57 2 AN 0.5
## 52 26 201 1 58 2 AN 0.5
## 53 27 132 1 10 2 GN 1.1
## 54 27 156 1 10 2 GN 1.1
## 55 28 34 1 52 2 AN 1.8
## 56 28 30 1 52 2 AN 1.8
## 57 29 2 1 53 1 GN 1.5
## 58 29 25 1 53 1 GN 1.5
## 59 30 130 1 54 2 GN 1.5
## 60 30 26 1 54 2 GN 1.5
## 61 31 27 1 56 2 AN 1.7
## 62 31 58 1 56 2 AN 1.7
## 63 32 5 0 50 2 AN 1.3
## 64 32 43 1 51 2 AN 1.3
## 65 33 152 1 57 2 PKD 2.9
## 66 33 30 1 57 2 PKD 2.9
## 67 34 190 1 44 2 GN 0.7
## 68 34 5 0 45 2 GN 0.7
## 69 35 119 1 22 2 Other 2.2
## 70 35 8 1 22 2 Other 2.2
## 71 36 54 0 42 2 Other 0.7
## 72 36 16 0 42 2 Other 0.7
## 73 37 6 0 52 2 PKD 2.1
## 74 37 78 1 52 2 PKD 2.1
## 75 38 63 1 60 1 PKD 1.2
## 76 38 8 0 60 1 PKD 1.2
y <- kidney$time
delta <- kidney$status
Surv(y, delta)
## [1] 8 16 23 13+ 22 28 447 318 30 12 24 245 7 9 511
## [16] 30 53 196 15 154 7 333 141 8+ 96 38 149+ 70+ 536 25+
## [31] 17 4+ 185 177 292 114 22+ 159+ 15 108+ 152 562 402 24+ 13
## [46] 66 39 46+ 12 40 113+ 201 132 156 34 30 2 25 130 26
## [61] 27 58 5+ 43 152 30 190 5+ 119 8 54+ 16+ 6+ 78 63
## [76] 8+
fit <- survfit(Surv(time, status) ~ disease, data = kidney)
plot(fit, xlab = "Time", ylab = "Survival rate", col = c("red", "green", "blue", "black"))
legend(300, 0.8, legend = c("Other", "GN", "AN", "PKD"),
lty = 1, col = c("red", "green", "blue", "black"))
cox.lasso <- function(X, y, delta, lambda = 0) { ## assumes rows are sorted by time and the global p
delta[1] <- 1 ## treat the earliest observation as uncensored
n <- length(y)
w <- array(dim = n)
u <- array(dim = n)
pi <- array(dim = c(n, n))
beta <- rnorm(p)
gamma <- rep(0, p)
while (sum((beta - gamma) ^ 2) > 1e-4) {
beta <- gamma
s <- as.vector(X %*% beta)
v <- exp(s)
for (i in 1:n)
for (j in 1:n)
pi[i, j] <- v[i] / sum(v[j:n])
for (i in 1:n) {
u[i] <- delta[i]
w[i] <- 0
for (j in 1:i) {
if (delta[j] == 1) {
u[i] <- u[i] - pi[i, j]
w[i] <- w[i] + pi[i, j] * (1 - pi[i, j])
}
}
}
z <- s + u / w
W <- diag(w)
print(gamma)
gamma <- W.linear.lasso(X, z, W, lambda = lambda)[-1]
}
return(gamma)
}
df <- kidney
index <- order(df$time)
df <- df[index, ]
n <- nrow(df)
p <- 4
y <- as.numeric(df[[2]])
delta <- as.numeric(df[[3]])
X <- as.numeric(df[[4]])
for (j in 5:7)
X <- cbind(X, as.numeric(df[[j]]))
z <- Surv(y, delta)
cox.lasso(X, y, delta, 0)
## [1] 0 0 0 0
## [1] 0.0101287 -1.7747758 -0.3887608 1.3532378
## [1] 0.01462571 -1.69299527 -0.41598742 1.38980788
## [1] 0.01591941 -1.66769665 -0.42331475 1.40330234
## [1] 0.01628935 -1.66060178 -0.42528537 1.40862969
cox.lasso(X, y, delta, 0.1)
## [1] 0 0 0 0
## [1] 0.00000000 -1.04501150 -0.08845692 1.00600836
## [1] 0.00000000 -0.97761575 -0.05957538 0.97284124
## [1] 0.0000000 -0.9567328 -0.0530083 0.9592735
## [1] 0.00000000 -0.95074758 -0.05159338 0.95420530
cox.lasso(X, y, delta, 0.2)
## [1] 0 0 0 0
## [1] 0.0000000 -0.5312546 0.0000000 0.7201078
## [1] 0.0000000 -0.5090856 0.0000000 0.6857966
## [1] 0.0000000 -0.5048733 0.0000000 0.6768827
glmnet(X, z, family = "cox", lambda = 0.1)$beta
## 4 x 1 sparse Matrix of class "dgCMatrix"
## s0
## X .
## -0.87359015
## -0.05659599
## 0.92923820
library(survival)
load("LymphomaData.rda")
attach("LymphomaData.rda")
## The following object is masked _by_ .GlobalEnv:
##
## patient.data
names(patient.data)
## [1] "x" "time" "status"
x <- t(patient.data$x)
y <- patient.data$time
delta <- patient.data$status
Surv(y, delta)
## [1] 5.0+ 5.9+ 6.6+ 13.1+ 1.6 1.3 1.4 2.2 3.4 2.2 9.8+ 9.1+
## [13] 6.2+ 9.1+ 3.8 5.5 5.3+ 3.5 2.3 1.1 2.7 8.2+ 1.6 1.0
## [25] 1.3 1.3 3.5 11.8+ 8.7+ 6.9+ 17.9 10.7+ 11.1+ 2.6+ 1.8 4.9
## [37] 4.3 8.1+ 4.3+ 7.8+ 1.0 1.1 3.7 22.8+ 1.8 7.2 8.2+ 1.4
## [49] 8.2+ 7.0+ 2.0 2.7+ 1.6 3.3+ 6.4 5.1 15.4+ 15.2+ 2.1 1.7
## [61] 5.0+ 2.0 2.7+ 3.0+ 1.4 2.9+ 10.7+ 1.1 8.4+ 6.0+ 11.0+ 1.3
## [73] 2.3 6.5+ 4.4+ 7.9 2.9 1.3 11.5+ 1.4 1.3 3.6 1.3 2.3
## [85] 2.5 1.5 3.4 1.9 3.8 1.1 3.3+ 1.0 1.2 13.2+ 4.3+ 7.8
## [97] 3.5+ 2.4 2.0 8.6+ 2.6 10.2+ 5.2+ 2.1 6.9+ 1.7 1.7 4.9+
## [109] 1.3 3.8+ 14.3+ 9.4+ 2.0 11.3+ 6.6+ 14.8+ 17.8+ 1.9 6.6 1.8+
## [121] 3.7+ 3.8+ 2.0+ 5.8+ 7.7+ 1.2 10.1+ 1.7 1.3 1.3 7.7+ 6.8+
## [133] 8.1+ 3.3 7.5+ 10.7+ 1.0 15.6+ 3.9 7.6 3.3 12.6+ 1.2 1.7
## [145] 3.1 3.3 3.3 2.3 1.1 1.5 2.0 13.3+ 8.8+ 3.1 1.4 3.0
## [157] 7.5+ 2.0 8.3+ 7.0 12.0 11.6 2.0 6.4 1.8 11.5+ 10.6+ 10.0+
## [169] 8.4+ 8.5 12.3+ 3.0 1.4 12.4+ 4.0 9.4+ 1.2 6.0+ 1.9 1.6
## [181] 1.6 7.4+ 5.8+ 6.0 1.3 10.5+ 5.1+ 9.9+ 2.5 2.1 4.9+ 1.1
## [193] 2.3 16.4+ 10.1 5.3 1.1 4.6 1.4 11.2+ 11.4+ 1.7 1.0 1.7
## [205] 1.1 5.1+ 4.5+ 1.5 12.7+ 7.1 2.1 2.9 10.5+ 1.4 2.3 2.1
## [217] 1.7 2.6 4.4+ 10.2+ 2.2 12.2+ 2.4 9.8+ 4.0 1.4 3.9 1.4
## [229] 18.4+ 17.8+ 2.3 5.0+ 6.6 20.8+ 7.9 1.3 4.1 8.7+ 3.7 2.1
library(ranger)
## Warning: package 'ranger' was built under R version 4.0.3
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.0.3
library(dplyr)
## Warning: package 'dplyr' was built under R version 4.0.3
##
## Attaching package: 'dplyr'
## The following object is masked from 'package:MASS':
##
## select
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
library(ggfortify)
## Warning: package 'ggfortify' was built under R version 4.0.3
cv.fit <- cv.glmnet(x, Surv(y, delta), family = "cox")
fit2 <- glmnet(x, Surv(y, delta), lambda = cv.fit$lambda.min, family = "cox")
z <- sign(drop(x %*% fit2$beta))
fit3 <- survfit(Surv(y, delta) ~ z)
autoplot(fit3)
mean(y[z == 1])
## [1] 3.146429
mean(y[z == -1])
## [1] 7.392188
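## A hedged follow-up (not in the original): the separation between the two
## groups defined by the sign of the lasso score can also be assessed with a
## log-rank test.
survdiff(Surv(y, delta) ~ z)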