Ch.5 MMD と HSIC(問題65~83)

In [ ]:
# 第 5 章のプログラムは,事前に下記が実行されていることを仮定する。
import numpy as np
from scipy.stats import kde
import itertools
import math
import matplotlib.pyplot as plt
from matplotlib import style
# Plot style for the whole chapter. Matplotlib 3.6 renamed the bundled
# "seaborn-ticks" style to "seaborn-v0_8-ticks" and 3.8 removed the old name
# (style.use then raises OSError), so try the new name first and fall back to
# the old one for Matplotlib < 3.6.
try:
    style.use("seaborn-v0_8-ticks")
except OSError:
    style.use("seaborn-ticks")

問題 76

In [ ]:
def cc(x, y):
    """Return the sample cross-moment sum(x^T y) / len(x).

    For centered 1-D arrays this is the sample covariance with a 1/n divisor.
    """
    n_obs = len(x)
    cross = np.dot(x.T, y)
    return np.sum(cross) / n_obs


def f(u, v):
    """Return the residual of u after a no-intercept regression on v.

    The slope is cc(u, v) / cc(v, v), so the result is u with its component
    along v removed.
    """
    slope = cc(u, v) / cc(v, v)
    return u - slope * v
In [ ]:
# Data generation: y depends linearly on x, and z depends on both x and y.
# Every noise term is the difference of two squared standard normals, i.e.
# symmetric but non-Gaussian.
n = 30
x = np.square(np.random.randn(n)) - np.square(np.random.randn(n))
y = 2 * x + np.square(np.random.randn(n)) - np.square(np.random.randn(n))
z = x + y + np.square(np.random.randn(n)) - np.square(np.random.randn(n))
# Center each variable (after all three have been generated) so that mean
# cross-products act as covariances.
x, y, z = (col - np.mean(col) for col in (x, y, z))


# 上流を推定
def cc(x, y):
    """Return the sample cross-moment of x and y, summed after scaling by 1/len(x).

    For centered 1-D arrays this is the sample covariance with a 1/n divisor.
    """
    scaled = np.dot(x.T, y) / len(x)
    return np.sum(scaled)


def f(u, v):
    """Return u minus its least-squares projection onto v (no intercept)."""
    slope = cc(u, v) / cc(v, v)
    return u - slope * v


# Residuals for every ordered pair: a_b = f(a, b) is `a` with its linear
# dependence on `b` removed.
x_y = f(x, y)
y_z = f(y, z)
z_x = f(z, x)
x_z = f(x, z)
z_y = f(z, y)
y_x = f(y, x)

# Upstream (top) variable: the candidate whose value is most independent
# (smallest HSIC) of the residuals of the other two variables regressed on it.
# NOTE(review): HSIC_1, HSIC_2 and the kernels k_x, k_y, k_z are not defined in
# this cell; they are assumed to come from earlier cells of this chapter.
v1 = HSIC_2(x, y_x, z_x, k_x, k_y, k_z)
v2 = HSIC_2(y, z_y, x_y, k_y, k_z, k_x)
v3 = HSIC_2(z, x_z, y_z, k_z, k_x, k_y)

# Same comparisons (and tie-breaking) as the original nested if/else.
if v1 < v2:
    top = 1 if v1 < v3 else 3
else:
    top = 2 if v2 < v3 else 3

# Residual of each variable after removing BOTH other variables: regressing
# one residual on another residual of the same variable removes it as well
# (Frisch-Waugh), e.g. x_yz is x with y and z removed.
x_yz = f(x_y, z_y)
y_zx = f(y_z, x_z)
z_xy = f(z_x, y_x)

# Downstream order of the two remaining variables p, q: p is the middle one if
# p's residual on `top` is independent of q's residual on all other variables.
# The HSIC_1 calls below fill the exercise blanks (1)-(6).
if top == 1:
    v1 = HSIC_1(y_x, z_xy, k_y, k_z)  # is y the middle variable?
    v2 = HSIC_1(z_x, y_zx, k_z, k_y)  # is z the middle variable?
    if v1 < v2:
        middle, bottom = 2, 3
    else:
        middle, bottom = 3, 2
elif top == 2:
    v1 = HSIC_1(z_y, x_yz, k_z, k_x)  # is z the middle variable?
    v2 = HSIC_1(x_y, z_xy, k_x, k_z)  # is x the middle variable?
    if v1 < v2:
        middle, bottom = 3, 1
    else:
        middle, bottom = 1, 3
else:  # top == 3
    v1 = HSIC_1(x_z, y_zx, k_x, k_y)  # is x the middle variable?
    v2 = HSIC_1(y_z, x_yz, k_y, k_x)  # is y the middle variable?
    if v1 < v2:
        middle, bottom = 1, 2
    else:
        middle, bottom = 2, 1

# 結果を出力 (print the estimated order)
print("top =", top)
print("middle =", middle)
print("bottom =", bottom)

問題 77

In [ ]:
# Permutation test of independence with the HSIC_1 statistic, plotted as the
# estimated null density, the 95% critical value (red, dashed) and the
# observed statistic (blue).
# NOTE(review): HSIC_1, the kernels k_x, k_y and the sample size n are not
# defined in this cell; they are assumed to come from earlier cells.
# `scipy.stats.kde` is a deprecated private namespace; use the public name.
from scipy.stats import gaussian_kde

# Data generation: x and y are drawn independently (the null hypothesis holds).
x = np.random.randn(n)
y = np.random.randn(n)
u = HSIC_1(x, y, k_x, k_y)  # observed statistic

# Null distribution: shuffling x breaks any x-y dependence while keeping both
# marginal distributions intact.
m = 100
w = [HSIC_1(np.random.permutation(x), y, k_x, k_y) for _ in range(m)]
v = np.quantile(w, 0.95)  # critical value at the 5% level

# Separate grid name so the data array x is not overwritten by the plot axis.
grid = np.linspace(min(min(w), u, v), max(max(w), u, v), 200)
density = gaussian_kde(w)
plt.plot(grid, density(grid))
plt.axvline(x=v, c="r", linestyle="--")
plt.axvline(x=u, c="b")