# Data generation: n samples of three variables with a linear structure in
# which y depends on x, and z depends on x and y. Each noise term is the
# difference of two squared standard-normal draws, so the noise is
# non-Gaussian.
n = 30
x = np.square(np.random.randn(n)) - np.square(np.random.randn(n))
y = 2 * x + np.square(np.random.randn(n)) - np.square(np.random.randn(n))
z = x + y + np.square(np.random.randn(n)) - np.square(np.random.randn(n))
# Center every variable so that cc() below acts as a (1/n) covariance.
x, y, z = (column - np.mean(column) for column in (x, y, z))
# Estimate upstream (the root of the causal order).
# The two helpers below are the regression primitives used by both the
# upstream and the downstream steps.
def cc(x, y):
    """Return the sample cross-moment ``dot(x.T, y) / len(x)``.

    For centered inputs this is the sample covariance with 1/n
    normalisation. ``np.sum`` collapses a (1, 1) result to a scalar when
    column vectors are passed; for 1-D arrays it is a no-op.
    """
    n_samples = len(x)
    inner = np.dot(x.T, y)
    return np.sum(inner / n_samples)


def f(u, v):
    """Return the residual of ``u`` after least-squares regression on ``v``.

    The regression has no intercept; the slope is ``cc(u, v) / cc(v, v)``.
    ``v`` must not be all zeros, otherwise the slope divides by zero.
    """
    slope = cc(u, v) / cc(v, v)
    return u - slope * v
# Pairwise residuals: a_b = f(a, b) is the residual of a after linear
# regression on b.
x_y, y_z, z_x = f(x, y), f(y, z), f(z, x)
x_z, z_y, y_x = f(x, z), f(z, y), f(y, x)
# Score each variable as the root candidate against the residuals of the
# other two on it. NOTE(review): HSIC_2 is defined elsewhere; it is presumably
# an HSIC-based dependence score where smaller means "more independent" —
# the code below takes the smallest score as the root (1 = x, 2 = y, 3 = z).
v1 = HSIC_2(x, y_x, z_x, k_x, k_y, k_z)
v2 = HSIC_2(y, z_y, x_y, k_y, k_z, k_x)
v3 = HSIC_2(z, x_z, y_z, k_z, k_x, k_y)
# Ties deliberately fall through to the later candidates (e.g. v1 == v2 goes
# on to compare v2 with v3).
if v1 < v2:
    top = 1 if v1 < v3 else 3
else:
    top = 2 if v2 < v3 else 3
# Estimate downstream: with the root (`top`) fixed, decide which of the two
# remaining variables is in the middle and which is at the bottom
# (1 = x, 2 = y, 3 = z).
# Two-subscript names are residuals after regressing out both named
# variables: e.g. x_yz = f(x_y, z_y) regresses the residual of x on y against
# the residual of z on y, which equals the residual of x on (y, z).
#
# NOTE(review): the six scores were left blank in the original (a syntax
# error). They are filled with HSIC_1, assumed to be the two-variable
# counterpart of HSIC_2 with signature HSIC_1(a, b, k_a, k_b), defined
# alongside HSIC_2 — confirm. As in the upstream step, a smaller score is
# treated as "more independent".
x_yz = f(x_y, z_y)
y_zx = f(y_z, x_z)
z_xy = f(z_x, y_x)
if top == 1:
    # Remaining y and z. v1 tests y as middle: y_x vs. residual of z on (x, y).
    # v2 tests z as middle: z_x vs. residual of y on (x, z).
    v1 = HSIC_1(y_x, z_xy, k_y, k_z)
    v2 = HSIC_1(z_x, y_zx, k_z, k_y)
    if v1 < v2:
        middle = 2
        bottom = 3
    else:
        middle = 3
        bottom = 2
elif top == 2:
    # Remaining z and x. v1 tests z as middle, v2 tests x as middle.
    v1 = HSIC_1(z_y, x_yz, k_z, k_x)
    v2 = HSIC_1(x_y, z_xy, k_x, k_z)
    if v1 < v2:
        middle = 3
        bottom = 1
    else:
        middle = 1
        bottom = 3
elif top == 3:
    # Remaining x and y. v1 tests x as middle, v2 tests y as middle.
    v1 = HSIC_1(x_z, y_zx, k_x, k_y)
    v2 = HSIC_1(y_z, x_yz, k_y, k_x)
    if v1 < v2:
        middle = 1
        bottom = 2
    else:
        middle = 2
        bottom = 1
else:
    # The upstream step only ever yields 1, 2 or 3; fail loudly otherwise
    # instead of leaving middle/bottom undefined.
    raise ValueError(f"top must be 1, 2 or 3, got {top!r}")
# Print the estimated causal order (1 = x, 2 = y, 3 = z).
for caption, rank in (("top =", top), ("middle =", middle), ("bottom =", bottom)):
    print(caption, rank)