import dssvi_func
import ssvi_org
import time
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from scipy.stats import poisson
def plot_results(dssvi_obj, ssvi_obj):
    """Plot two metric traces of a DSSVI run against an SSVI run, one figure each.

    Column 2 of every ``metrics_list`` entry is drawn as "RRMSE" (y in [0, 1]) and
    column 3 as "Log RMSE" (y in [0, 2]), both against the run's ``time_list``.
    Each figure is shown before the next one is started.
    """
    # (metrics_list column, y-axis label, figure title, upper y-limit)
    panels = (
        (2, "RRMSE", "RRMSE across time", 1),
        (3, "Log RMSE", "Log RMSE across time", 2),
    )
    for col, y_label, heading, y_top in panels:
        for run, method in ((dssvi_obj, "DSSVI"), (ssvi_obj, "SSVI")):
            trace = [row[col] for row in run.metrics_list]
            plt.plot(run.time_list, trace, label=method)
        plt.legend()
        plt.xlabel("time")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.ylim(0, y_top)
        plt.show()
    return
def _log_nonzero(X):
    """Return element-wise log(X) as a DataFrame, with 0 wherever X is 0.

    Zeros are masked to NaN before the log so they show up as 0 instead of -inf;
    every NaN in the result (including log of negative entries) is then set to 0.
    """
    # pd.DataFrame() accepts both ndarrays and DataFrames; passing a DataFrame to
    # DataFrame.from_records is deprecated since pandas 2.1.
    frame = pd.DataFrame(X)
    return np.log(frame.replace(0, np.nan)).replace(np.nan, 0)


def show_results(obj, X_true, dssvi=True, *, W_true=None, H_true=None,
                 pi_W_true=None, pi_H_true=None):
    """Plot diagnostics of a fitted DSSVI or SSVI model against the ground truth.

    Calls the backend's ``K_and_log_ll`` and ``plot_pi`` helpers. It then draws
    true vs. posterior-mean W, H and X side by side, and finally log(X) for both
    the true and the reconstructed matrix. In the log plot, zero entries are
    shown as 0.

    Args:
        obj: fitted model; its helpers come from ``dssvi_func`` when ``dssvi`` is
            true and from ``ssvi_org`` otherwise.
        X_true: noise-free data matrix (ndarray or DataFrame).
        dssvi: selects the ``dssvi_func`` (true) or ``ssvi_org`` (false) helpers.
        W_true: true W * S_W; defaults to the module-level ``W_S_W``.
        H_true: true H * S_H; defaults to the module-level ``H_S_H``.
        pi_W_true: true per-component densities of S_W; defaults to the
            module-level ``true_pi_W`` (only used by the DSSVI branch).
        pi_H_true: true per-component densities of S_H; defaults to the
            module-level ``true_pi_H``.
    """
    # Fall back to the ground truth of the experiment currently set up at module level.
    if W_true is None:
        W_true = W_S_W
    if H_true is None:
        H_true = H_S_H
    if pi_W_true is None:
        pi_W_true = true_pi_W
    if pi_H_true is None:
        pi_H_true = true_pi_H

    if dssvi:
        # Show the number of components K and the log-likelihood plot.
        dssvi_func.K_and_log_ll(obj)
        # Plot the expected pi.
        dssvi_func.plot_pi(obj, true_pi_W=pi_W_true, true_pi_H=pi_H_true)
        W_post_mean, H_post_mean, X_post_mean = dssvi_func.get_post_means(obj)
        draw = dssvi_func.draw_two_matrices
        # obj.good_k presumably selects the components kept by the fit -- verify.
        draw(W_true, W_post_mean * obj.S_W[:, obj.good_k], 1, 1, 'true W * S_W', 'E(W) * S_W')
        draw(H_true, H_post_mean * obj.S_H[obj.good_k], 1, 1, 'true H * S_H', 'E(H) * S_H')
    else:
        # Show the number of components K and the log-likelihood plot.
        ssvi_org.K_and_log_ll(obj)
        # Plot the expected pi (the SSVI backend only takes the H-side truth).
        ssvi_org.plot_pi(obj, true_pi_H=pi_H_true)
        W_post_mean, H_post_mean, X_post_mean = ssvi_org.get_post_means(obj)
        draw = ssvi_org.draw_two_matrices
        draw(W_true, W_post_mean, 1, 1, 'true W * S_W', 'E(W)')
        draw(H_true, H_post_mean * obj.S[obj.good_k], 1, 1, 'true H * S_H', 'E(H) * S_H')

    draw(X_true, X_post_mean, 1, 1, first='True X', second='Reconstructed X')
    draw(_log_nonzero(X_true), _log_nonzero(X_post_mean), 1, 1,
         'log(X true)', 'log(X_reconst) (0s not taken log)')
    return
def RRMSE_RlogMSE(dssvi_store, ssvi_store):
    """Plot per-run RRMSE and RlogMSE traces of DSSVI vs. SSVI in a grid.

    Row ``i`` of the figure compares ``dssvi_store[i]`` with ``ssvi_store[i]``.
    The left panel plots column 2 of each ``metrics_list`` entry (titled RRMSE,
    y in [0, 1]) and the right panel plots column 3 (titled RlogMSE, y in [0, 4]),
    both against ``time_list``.

    Args:
        dssvi_store: fitted DSSVI runs, one grid row each.
        ssvi_store: fitted SSVI runs, index-aligned with ``dssvi_store``; must be
            at least as long.
    """
    n_runs = len(dssvi_store)
    if n_runs == 0:
        # plt.subplots rejects a zero-row grid, and there is nothing to draw.
        return
    # squeeze=False keeps ``ax`` 2-D even for a single run, so ax[i, j] is always valid.
    fig, ax = plt.subplots(n_runs, 2, figsize=(14, 14), squeeze=False)
    # (metrics_list column, title prefix, y-limits) for the left and right panel.
    panels = ((2, "RRMSE", (0, 1)), (3, "RlogMSE", (0, 4)))
    for i, dssvi_run in enumerate(dssvi_store):
        ssvi_run = ssvi_store[i]
        for j, (col, name, ylim) in enumerate(panels):
            axis = ax[i, j]
            axis.plot(dssvi_run.time_list,
                      [m[col] for m in dssvi_run.metrics_list], label="DSSVI")
            axis.plot(ssvi_run.time_list,
                      [m[col] for m in ssvi_run.metrics_list], label="SSVI")
            axis.set_title(f"{name}, run={i+1}", fontsize=14)
            axis.set_ylim(ylim)
            axis.set_xlabel("time (sec)", fontsize=14)
            axis.legend()
    fig.tight_layout(pad=3.0)
    plt.show()
    return
# ---------------------------------------------------------------------------
# Experiment 1: K = 4 (three sparse components + one dense component)
# ---------------------------------------------------------------------------
# Fraction of non-zero entries in each column of S_W / each row of S_H.
true_pi_W = [0.3, 0.2, 0.3, 1]
true_pi_H = [0.4, 0.5, 0.45, 1]
# Sparsity masks: S_W is (100, 4) before row repetition, S_H is (4, 100).
S_W = np.vstack((np.hstack((np.ones(30), np.zeros(70))),
                 np.hstack((np.zeros(20), np.ones(20), np.zeros(60))),
                 np.hstack((np.zeros(50), np.ones(30), np.zeros(20))),
                 np.ones(100))).T
S_H = np.vstack((np.hstack((np.ones(40), np.zeros(60))),
                 np.hstack((np.zeros(20), np.ones(50), np.zeros(30))),
                 np.hstack((np.zeros(55), np.ones(45))),
                 np.ones(100)))
# Repeat every row of S_W 6 times -> (600, 4); column densities are unchanged.
S_W = S_W.repeat(6, axis=0)
dssvi_func.show_true_S(S_W, S_H)
np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape parameter.
cluster1 = np.random.gamma(1.6, 1, 600)  # mean 1.6
cluster2 = np.random.gamma(1.4, 1, 600)  # mean 1.4
cluster3 = np.random.gamma(1.2, 1, 600)  # mean 1.2
dense = np.random.gamma(1, 1, 10000)     # mean 1.0; only the first 600 draws are used below
clusters = [cluster1, cluster2, cluster3, dense]
# Populate W where S_W is non-zero: entry (i, j) is clusters[j][i], else 0.
W_values = np.column_stack([c[:S_W.shape[0]] for c in clusters])
W_S_W = np.where(S_W != 0, W_values, 0.0)
# Similarly for H: entry (i, j) is clusters[i][j] where S_H is non-zero, else 0.
H_values = np.vstack([c[:S_H.shape[1]] for c in clusters[:S_H.shape[0]]])
H_S_H = np.where(S_H != 0, H_values, 0.0)
def show_true_combined(W=None, H=None):
    """Draw the true masked factors W * S_W and H * S_H as heatmaps with colorbars.

    Each matrix gets its own figure: W at 4x5 inches and H at 6x3 inches.
    ``plt.show()`` is not called here.

    Args:
        W: matrix for the first figure; defaults to the module-level ``W_S_W``.
        H: matrix for the second figure; defaults to the module-level ``H_S_H``.
    """
    # Fall back to the ground truth of the experiment currently set up at module level.
    if W is None:
        W = W_S_W
    if H is None:
        H = H_S_H
    for matrix, size in ((W, (4, 5)), (H, (6, 3))):
        plt.figure(figsize=size)
        plt.imshow(matrix, aspect="auto", interpolation="none")
        plt.colorbar()
show_true_combined()
# Noise-free data X = (W * S_W)(H * S_H). It stays an ndarray, as X_true_k3 does
# below; a separate DataFrame copy is used only for the log display.
X_true_k4 = W_S_W.dot(H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true_k4, aspect="auto", interpolation="none")
plt.colorbar()
# Notebook output: <matplotlib.colorbar.Colorbar at 0x1ede37e3e08>
# log(X) for display only; zero entries are left as 0 instead of -inf.
X_true_k4_df = pd.DataFrame(X_true_k4)
X_true_log = np.log(X_true_k4_df.replace(0, np.nan)).replace(np.nan, 0)
dssvi_func.draw_two_matrices(X_true_k4_df, X_true_log, 1, 1, 'X_true', "log(X_true) (0's not taken log)")
# Observed counts: element-wise Poisson draws with mean X_true_k4.
X_1_k4 = poisson.rvs(X_true_k4, size=X_true_k4.shape)
# Notebook output of `X_1_k4`:
# array([[18, 7, 2, ..., 0, 0, 0], [ 9, 5, 5, ..., 0, 0, 0], [ 6, 7, 2, ..., 0, 0, 0], ..., [ 0, 0, 0, ..., 0, 0, 0], [ 0, 0, 0, ..., 0, 0, 0], [ 0, 0, 0, ..., 0, 0, 0]])
# Hyperparameters a, b, c, d passed to SSMF_BP_NMF, one row per run.
params = [[5, 5, 5, 5],      # Run 1
          [10, 10, 10, 10],  # Run 2
          [1, 1, 1, 1]]      # Run 3
ssvi_store = []
dssvi_store = []
# Run simulations using the given parameters: both methods are built with the
# same settings and fitted to the same Poisson sample X_1_k4.
for num, param in enumerate(params):
    print(f"Run: {num}")
    ssvi_model = ssvi_org.SSMF_BP_NMF(n_components=4, burn_in=20, random_state=10, verbose=False,
                                      max_iter=50, cutoff=1e-2, X_true=X_true_k4,
                                      a=param[0], b=param[1], c=param[2], d=param[3])
    dssvi_model = dssvi_func.SSMF_BP_NMF(n_components=4, burn_in=20, random_state=10, verbose=False,
                                         max_iter=50, cutoff=1e-2, X_true=X_true_k4,
                                         a=param[0], b=param[1], c=param[2], d=param[3])
    ssvi_store.append(ssvi_model)
    dssvi_store.append(dssvi_model)
    print("\t ssvi...")
    ssvi_model.fit(X_1_k4)
    print("\t dssvi...")
    dssvi_model.fit(X_1_k4)
# Notebook output:
# Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
RRMSE_RlogMSE(dssvi_store, ssvi_store)
show_results(dssvi_store[2], X_true=X_true_k4, dssvi=True)
# Notebook output of the call above (presumably the log-likelihood trace printed by K_and_log_ll -- verify):
# [-127414.33186326 -98824.51334246 -97902.95416128 -97058.18308324 -95725.64339642 -94062.57062185 -91246.23772688 -89197.57961338 -87404.72623926 -86269.8066911 -85367.72334341 -84612.99252845 -84067.76425987 -83531.59706251 -83166.51473352 -82660.002681 -82167.11897145 -81868.80789714 -81286.50443718 -80626.19841381 -80254.2125516 -79831.13513589 -79522.61917274 -79105.09372157 -78934.0256242 -78597.02752567 -78353.72037559 -78038.08053033 -77844.76183216 -77574.39469152 -77334.94290028 -77137.58750796 -76909.71804728 -76607.5548623 -76394.24337328 -76165.50801186 -75931.01340064 -75777.55396724 -75545.8266704 -75260.37603479 -75076.70171776 -74898.60939691 -74704.23430538 -74415.33519171 -74334.21739262 -74116.46800339 -73911.20730215 -73760.33640381 -73656.39288997 -73515.33324417]
show_results(ssvi_store[2], X_true=X_true_k4, dssvi=False)
# Notebook output of the call above:
# [-132043.34431192 -96154.43046107 -94972.73676987 -93039.03558419 -90633.15407096 -88539.64135547 -86783.32362342 -85374.5577818 -84062.35179766 -82877.81438213 -81607.88037224 -80666.3474229 -79768.12200623 -78803.05478515 -78016.59917117 -77128.96662992 -76484.31069821 -75800.82805133 -75194.44612423 -74822.78817809 -74495.38818069 -74133.79876432 -73884.80365052 -73601.47520622 -73351.77431087 -73252.64537268 -73061.75675649 -72826.654027 -72761.30358704 -72589.25472247 -72542.96644305 -72441.13739005 -72221.51510657 -72199.49353792 -72053.38455572 -72046.87054609 -71837.62681256 -71845.05664258 -71714.45273264 -71645.61197899 -71606.42007164 -71520.46695288 -71495.37035537 -71449.17484713 -71388.68470043 -71265.81491825 -71194.68660821 -71142.63272494 -71080.60270439 -71103.21622598]
# Overlay every K = 4 run: DSSVI curves in blue, SSVI curves in red. Only run 0
# carries a legend label, so each method appears once in the legend.
for metric_col, y_label, plot_title in ((2, "RRMSE", "RRMSE over time"),
                                        (3, "R_log_MSE", "R_log_MSE over time")):
    plt.figure(figsize=(8, 5))
    for run_no, dssvi_run in enumerate(dssvi_store):
        ssvi_run = ssvi_store[run_no]
        dssvi_kw = {"label": "DSSVI"} if run_no == 0 else {}
        ssvi_kw = {"label": "SSVI"} if run_no == 0 else {}
        plt.plot(dssvi_run.time_list,
                 [row[metric_col] for row in dssvi_run.metrics_list],
                 color="b", **dssvi_kw)
        plt.plot(ssvi_run.time_list,
                 [row[metric_col] for row in ssvi_run.metrics_list],
                 color="r", **ssvi_kw)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    plt.title(plot_title, fontsize=16)
    plt.show()
# ---------------------------------------------------------------------------
# Experiment 2: K = 3 sparse components (no dense component)
# ---------------------------------------------------------------------------
# Fraction of non-zero entries in each column of S_W / each row of S_H.
true_pi_W = [0.3, 0.4, 0.5]
true_pi_H = [0.4, 0.5, 0.45]
# Sparsity masks: S_W is (100, 3), S_H is (3, 100).
S_W = np.vstack((np.hstack((np.ones(30), np.zeros(70))),
                 np.hstack((np.zeros(20), np.ones(40), np.zeros(40))),
                 np.hstack((np.zeros(50), np.ones(50))))).T
S_H = np.vstack((np.hstack((np.ones(40), np.zeros(60))),
                 np.hstack((np.zeros(20), np.ones(50), np.zeros(30))),
                 np.hstack((np.zeros(55), np.ones(45)))))
dssvi_func.show_true_S(S_W, S_H)
np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape parameter.
cluster1 = np.random.gamma(2, 1, 50)    # mean 2
cluster2 = np.random.gamma(1.5, 1, 50)  # mean 1.5
cluster3 = np.random.gamma(1, 1, 50)    # mean 1
# Fill each component's active block with the leading draws of its cluster.
# W and H reuse the same draws (e.g. cluster1[:30] for W, cluster1[:40] for H).
W_S_W = np.zeros(S_W.shape)
W_S_W[:30, 0] = S_W[:30, 0] * cluster1[:30]
W_S_W[20:60, 1] = S_W[20:60, 1] * cluster2[:40]
W_S_W[50:, 2] = S_W[50:, 2] * cluster3[:50]
H_S_H = np.zeros(S_H.shape)
H_S_H[0, :40] = S_H[0, :40] * cluster1[:40]
H_S_H[1, 20:70] = S_H[1, 20:70] * cluster2[:50]
H_S_H[2, 55:] = S_H[2, 55:] * cluster3[:45]
show_true_combined()
X_true_k3 = np.dot(W_S_W, H_S_H)
# Observed counts: element-wise Poisson draws with mean X_true_k3.
X_1_k3 = poisson.rvs(X_true_k3, size=X_true_k3.shape)
dssvi_func.draw_two_matrices(X_true_k3, X_1_k3, 1, 1)
# Hyperparameters a, b, c, d passed to SSMF_BP_NMF, one row per run.
params = [[5, 5, 5, 5],
          [10, 10, 10, 10],
          [1, 1, 1, 1]]
ssvi_store_k3 = []
dssvi_store_k3 = []
# Run simulations using the given parameters: both methods are built with the
# same settings and fitted to the same Poisson sample X_1_k3.
for num, param in enumerate(params):
    print(f"Run: {num}")
    ssvi_model = ssvi_org.SSMF_BP_NMF(n_components=3, burn_in=20, random_state=10, verbose=False,
                                      max_iter=50, cutoff=1e-2, X_true=X_true_k3,
                                      a=param[0], b=param[1], c=param[2], d=param[3])
    dssvi_model = dssvi_func.SSMF_BP_NMF(n_components=3, burn_in=20, random_state=10, verbose=False,
                                         max_iter=50, cutoff=1e-2, X_true=X_true_k3,
                                         a=param[0], b=param[1], c=param[2], d=param[3])
    ssvi_store_k3.append(ssvi_model)
    dssvi_store_k3.append(dssvi_model)
    print("\t ssvi...")
    ssvi_model.fit(X_1_k3)
    print("\t dssvi...")
    dssvi_model.fit(X_1_k3)
# Notebook output:
# Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
RRMSE_RlogMSE(dssvi_store_k3, ssvi_store_k3)
# Overlay every K = 3 run: DSSVI curves in blue, SSVI curves in red. Only run 0
# carries a legend label, so each method appears once in the legend. Iterating
# the K = 3 stores themselves keeps the run count in sync with the data plotted.
for metric_col, y_label, plot_title, y_lim in (
        (2, "RRMSE", "RRMSE over time", None),
        (3, "R_log_MSE", "R_log_MSE over time", (0, 2))):
    plt.figure(figsize=(8, 5))
    for run_no, (dssvi_run, ssvi_run) in enumerate(zip(dssvi_store_k3, ssvi_store_k3)):
        dssvi_kw = {"label": "DSSVI"} if run_no == 0 else {}
        ssvi_kw = {"label": "SSVI"} if run_no == 0 else {}
        plt.plot(dssvi_run.time_list,
                 [row[metric_col] for row in dssvi_run.metrics_list],
                 color="b", **dssvi_kw)
        plt.plot(ssvi_run.time_list,
                 [row[metric_col] for row in ssvi_run.metrics_list],
                 color="r", **ssvi_kw)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    if y_lim is not None:
        plt.ylim(*y_lim)
    plt.title(plot_title, fontsize=16)
    plt.show()
show_results(dssvi_store_k3[2], X_true=X_true_k3, dssvi=True)
# Notebook output of the call above (presumably the log-likelihood trace printed by K_and_log_ll -- verify):
# [-21258.29961476 -15018.46233748 -13608.79692695 -11871.7201569 -11182.40623411 -10838.00703002 -10631.50818569 -10417.51229203 -10198.62818884 -9330.23974047 -8503.73701455 -8089.4842343 -7906.89536895 -7764.11611607 -7713.19199425 -7699.32622879 -7650.45312003 -7631.27508839 -7611.52669743 -7582.915239 -7611.09895601 -7597.15835696 -7584.37688339 -7566.79554671 -7558.1669938 -7550.29791343 -7565.8048475 -7556.82477067 -7559.84403592 -7560.49695759 -7545.81947514 -7550.45067847 -7527.94909949 -7535.16748531 -7536.37900285 -7556.79071856 -7533.06664768 -7535.79920217 -7534.64682549 -7538.96089283 -7552.58144524 -7569.0896996 -7548.68487458 -7546.97902738 -7566.15262317 -7543.86934556 -7551.41057797 -7540.50031123 -7541.36793068 -7560.18086541]
show_results(ssvi_store_k3[2], X_true=X_true_k3, dssvi=False)
# Notebook output of the call above:
# [-24049.62347957 -15106.30313159 -14668.48186427 -13905.7210194 -12967.19490507 -11933.69547855 -11333.73966011 -10800.84434936 -10506.0934762 -10230.72857045 -9826.51779072 -9648.74373212 -9396.01506607 -9187.78924771 -8947.54856247 -8719.8807277 -8620.73298005 -8480.7888786 -8394.10666605 -8253.74711612 -8177.53658052 -8128.2726838 -8032.87760751 -7990.38000706 -7901.2376698 -7912.29562891 -7822.00146916 -7807.82228818 -7755.49640681 -7747.05343732 -7730.23251742 -7702.35967381 -7660.50408862 -7631.67500684 -7629.15480147 -7643.78308892 -7604.28321939 -7616.00998821 -7643.05413107 -7580.29299539 -7598.16363094 -7584.82508271 -7566.60903744 -7571.87238108 -7566.93826795 -7575.36004568 -7559.86994124 -7562.56711771 -7542.94226482 -7512.66171795]