import dssvi_func
import ssvi_org
import time
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from scipy.stats import poisson
def plot_results(dssvi_obj, ssvi_obj):
    """Plot two error metrics against time for a DSSVI / SSVI pair.

    Two figures are shown one after the other.  Each uses ``time_list`` of
    the given objects for the x-axis and one entry of every ``metrics_list``
    record for the y-axis: index 2 for the figure labelled "RRMSE" and
    index 3 for the figure labelled "Log RMSE".
    """
    # (metrics index, y-axis label, figure title, upper y-limit)
    figures = (
        (2, "RRMSE", "RRMSE across time", 1),
        (3, "Log RMSE", "Log RMSE across time", 2),
    )
    methods = ((dssvi_obj, "DSSVI"), (ssvi_obj, "SSVI"))

    for column, y_label, title, y_top in figures:
        for fitted, name in methods:
            plt.plot(fitted.time_list,
                     [record[column] for record in fitted.metrics_list],
                     label=name)
        plt.legend()
        plt.xlabel("time")
        plt.ylabel(y_label)
        plt.title(title)
        plt.ylim(0, y_top)
        plt.show()
    return
def show_results(obj, X_true, dssvi=True):
    """Show the diagnostic plots for one fitted model.

    The plotting helpers come from ``dssvi_func`` when ``dssvi`` is true and
    from ``ssvi_org`` otherwise.  The function calls ``K_and_log_ll``,
    ``plot_pi`` and ``get_post_means`` and then draws four side-by-side
    comparisons with ``draw_two_matrices``: true vs. estimated W, true vs.
    estimated H, true vs. reconstructed X, and the element-wise log of that
    last pair (zeros are left as 0 instead of being logged).

    The true factors and inclusion probabilities are read from the
    module-level names ``W_S_W``, ``H_S_H``, ``true_pi_W`` and ``true_pi_H``
    at call time, so they must match the experiment ``obj`` was fitted on.

    Parameters
    ----------
    obj
        Fitted model.  With ``dssvi`` true it must expose ``S_W``, ``S_H``
        and ``good_k``; otherwise ``S`` and ``good_k``.
    X_true
        Ground-truth matrix the reconstruction is compared with (2-D array
        or DataFrame).
    dssvi : bool, default True
        Selects the DSSVI (true) or SSVI (false) code path.
    """

    def log_nonzero(matrix):
        # pd.DataFrame accepts both ndarrays and DataFrames; passing a
        # DataFrame to DataFrame.from_records is deprecated in recent pandas.
        frame = pd.DataFrame(matrix)
        # Zeros are masked before the log and restored as 0 afterwards so
        # they do not turn into -inf.
        return np.log(frame.replace(0, np.nan)).replace(np.nan, 0)

    if dssvi:
        lib = dssvi_func
        # show # of K and log_ll plot
        lib.K_and_log_ll(obj)
        # plot Exp. pi
        lib.plot_pi(obj, true_pi_W=true_pi_W, true_pi_H=true_pi_H)
        # Calc. posterior means
        W_post_mean, H_post_mean, X_post_mean = lib.get_post_means(obj)
        W_shown = W_post_mean * obj.S_W[:, obj.good_k]
        W_label = 'E(W) * S_W'
        H_shown = H_post_mean * obj.S_H[obj.good_k]
    else:
        lib = ssvi_org
        # show # of K and log_ll plot
        lib.K_and_log_ll(obj)
        # plot Exp. pi (only pi_H is passed on this path)
        lib.plot_pi(obj, true_pi_H=true_pi_H)
        # Calc. posterior means
        W_post_mean, H_post_mean, X_post_mean = lib.get_post_means(obj)
        # W is shown unmasked on this path; only H is multiplied by obj.S.
        W_shown = W_post_mean
        W_label = 'E(W)'
        H_shown = H_post_mean * obj.S[obj.good_k]

    # draw matrices
    lib.draw_two_matrices(W_S_W, W_shown, 1, 1, 'true W * S_W', W_label)
    lib.draw_two_matrices(H_S_H, H_shown, 1, 1, 'true H * S_H', 'E(H) * S_H')
    lib.draw_two_matrices(X_true, X_post_mean, 1, 1, first='True X', second='Reconstructed X')
    # draw log
    lib.draw_two_matrices(log_nonzero(X_true), log_nonzero(X_post_mean), 1, 1,
                          'log(X true)', 'log(X_reconst) (0s not taken log)')
    return
def RRMSE_RlogMSE(dssvi_store, ssvi_store):
    """Draw a grid of per-run error curves comparing DSSVI with SSVI.

    The figure has one row per entry of ``dssvi_store`` and two columns.
    The left column plots index 2 of every ``metrics_list`` record (titled
    "RRMSE", y-range 0..1) and the right column plots index 3 (titled
    "RlogMSE", y-range 0..4), both against ``time_list``.

    Parameters
    ----------
    dssvi_store, ssvi_store
        Sequences of fitted objects exposing ``time_list`` and
        ``metrics_list``.  ``ssvi_store`` is indexed with the positions of
        ``dssvi_store``, so it must be at least as long.
    """
    # squeeze=False keeps ``ax`` two-dimensional even for a single run;
    # without it ``ax[i, j]`` fails when there is only one row.
    fig, ax = plt.subplots(len(dssvi_store), 2, figsize=(14, 14), squeeze=False)
    # (title prefix, y-limits) for the left and right column.
    columns = (("RRMSE", [0, 1]), ("RlogMSE", [0, 4]))

    for i, dssvi_run in enumerate(dssvi_store):
        ssvi_run = ssvi_store[i]
        for j, (metric_name, y_limits) in enumerate(columns):
            panel = ax[i, j]
            panel.plot(dssvi_run.time_list,
                       [record[j + 2] for record in dssvi_run.metrics_list],
                       label="DSSVI")
            panel.plot(ssvi_run.time_list,
                       [record[j + 2] for record in ssvi_run.metrics_list],
                       label="SSVI")
            panel.set_title(f"{metric_name}, run={i+1}", fontsize=14)
            panel.set_ylim(y_limits)
            panel.set_xlabel("time (sec)", fontsize=14)
            panel.legend()
    fig.tight_layout(pad=3.0)
    plt.show()
    return
# Ground truth for the K=4 experiment.
# true_pi_W[k] is the fraction of ones in column k of S_W and true_pi_H[k]
# the fraction of ones in row k of S_H (both built below).
true_pi_W = [0.8, 0.9, 0.8, 1]
true_pi_H = [0.4, 0.3, 0.3, 1]

# Binary support masks: S_W is 100 x 4 (one column per component) and S_H is
# 4 x 100 (one row per component).  The fourth component is fully dense.
S_W = np.vstack((np.hstack((np.ones(80), np.zeros(20))),
                 np.hstack((np.zeros(10), np.ones(90))),
                 np.hstack((np.ones(80), np.zeros(20))),
                 np.ones(100))).T
S_H = np.vstack((np.hstack((np.ones(40), np.zeros(60))),
                 np.hstack((np.zeros(30), np.ones(30), np.zeros(40))),
                 np.hstack((np.zeros(70), np.ones(30))),
                 np.ones(100)))
dssvi_func.show_true_S(S_W, S_H)

np.random.seed(10)
# Gamma(shape, scale=1) draws: the mean of each pool equals its shape.
cluster1 = np.random.gamma(1.6, 1, 600)  # component 1 values, mean 1.6
cluster2 = np.random.gamma(1.4, 1, 600)  # component 2 values, mean 1.4
cluster3 = np.random.gamma(1.2, 1, 600)  # component 3 values, mean 1.2
dense = np.random.gamma(1, 1, 10000)     # dense component values, mean 1
clusters = [cluster1, cluster2, cluster3, dense]

# Populate W where S_W is non-zero: W_S_W[i, k] = clusters[k][i].
# NOTE: H below reads the same leading samples of every pool, so column k of
# W and row k of H share values wherever both masks are on.
W_S_W = np.where(S_W != 0,
                 np.column_stack([pool[:S_W.shape[0]] for pool in clusters]),
                 0.0)
# Similarly for H: H_S_H[k, j] = clusters[k][j] where S_H is non-zero.
H_S_H = np.where(S_H != 0,
                 np.vstack([pool[:S_H.shape[1]] for pool in clusters]),
                 0.0)
def show_true_combined():
    """Show heatmaps of the module-level ``W_S_W`` and ``H_S_H`` matrices.

    Each matrix is drawn in its own figure with a colorbar.  The globals are
    looked up when the function is called, so it reflects whatever they hold
    at that moment.
    """
    for matrix, size in ((W_S_W, (4, 5)), (H_S_H, (6, 3))):
        plt.figure(figsize=size)
        plt.imshow(matrix, interpolation="none", aspect="auto")
        plt.colorbar()
# Display the masked true factors, then form the noiseless matrix
# X_true = W_S_W . H_S_H and show it as a heatmap.
show_true_combined()
X_true = np.dot(W_S_W, H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true, interpolation="none", aspect="auto")
plt.colorbar()
# Output: <matplotlib.colorbar.Colorbar at 0x1d36b7de108>
# From here on X_true is a DataFrame; its log is taken on the non-zero
# entries only, and zeros are displayed as 0.
X_true = pd.DataFrame.from_records(X_true)
X_true_log = np.log(X_true.replace(0, np.nan)).replace(np.nan, 0)
dssvi_func.draw_two_matrices(X_true, X_true_log, 1, 1, 'X_true', "log(X_true) (0's not taken log)")

# Observed counts: one Poisson draw per entry, with X_true as the rate.
X_1 = poisson.rvs(X_true, size=X_true.shape)

# One row per run; the four values are passed to both models as a, b, c, d.
params = [[5, 5, 5, 5],
          [10, 10, 10, 10],
          [1, 1, 1, 1]]
ssvi_store = [0] * len(params)
dssvi_store = [0] * len(params)

# Run simulations using given parameters
for num, param in enumerate(params):
    print(f"Run: {num}")
    # Both models are built with identical settings.
    model_kwargs = dict(n_components=4, burn_in=20, random_state=10, verbose=False,
                        max_iter=50, cutoff=1e-2, X_true=X_true,
                        a=param[0], b=param[1], c=param[2], d=param[3])
    ssvi_store[num] = ssvi_org.SSMF_BP_NMF(**model_kwargs)
    dssvi_store[num] = dssvi_func.SSMF_BP_NMF(**model_kwargs)
    print("\t ssvi...")
    ssvi_store[num].fit(X_1)
    print("\t dssvi...")
    dssvi_store[num].fit(X_1)
# Output: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run grid of both metrics for the K=4 fits.
RRMSE_RlogMSE(dssvi_store, ssvi_store)

# Overlay of metrics index 2 (RRMSE) for every run: DSSVI in blue, SSVI in red.
plt.figure(figsize=(8, 5))
for j in range(len(dssvi_store)):
    for run, colour, name in ((dssvi_store[j], "b", "DSSVI"),
                              (ssvi_store[j], "r", "SSVI")):
        # Only the first run is labelled so each method appears once in the legend.
        extra = {"label": name} if j == 0 else {}
        plt.plot(run.time_list,
                 [record[2] for record in run.metrics_list],
                 color=colour, **extra)
plt.legend()
plt.xlabel("time (sec)", fontsize=16)
plt.ylabel("RRMSE", fontsize=16)
plt.title("RRMSE over time", fontsize=16)
plt.show()

# Diagnostic plots for the first DSSVI run.
show_results(dssvi_store[0], X_true=X_true, dssvi=True)
# Output: [ -inf -19992.78635358 -19436.48842835 -19131.19087031 -18877.19134723 -18662.16836296 -18542.69044467 -18383.23146882 -18263.46928415 -18114.68241598 -18013.02459675 -17889.23269055 -17805.09651232 -17592.39246929 -17403.80106531 -17273.20829014 -17239.02483274 -17069.44162971 -16998.79151325 -16880.70872555 -16882.07069508 -16822.8717277 -16751.37693726 -16727.77705438 -16684.68599583 -16658.59646093 -16658.27811024 -16600.39305558 -16590.44620887 -16558.88748499 -16535.8211968 -16495.22677463 -16480.39763302 -16475.3365043 -16453.86196249 -16475.88411761 -16434.50965343 -16415.32036951 -16449.93417809 -16387.68329827 -16380.74816826 -16363.30867857 -16369.6239304 -16315.18712196 -16317.8747347 -16312.33834394 -16318.32468755 -16321.41111686 -16305.93093003 -16279.47246668]
# Diagnostic plots for the first SSVI run; dssvi=False makes show_results
# use the ssvi_org helpers.
show_results(ssvi_store[0], X_true=X_true, dssvi=False)
# Output: [-27108.64591072 -20229.08094292 -19938.71383253 -19813.02606838 -19683.5229818 -19602.67789796 -19477.1363875 -19384.76102278 -19198.07884605 -19113.47242741 -18939.93273445 -18751.86120124 -18596.21197244 -18419.07473069 -18292.60165557 -18257.31572759 -18054.45648089 -17889.46088585 -17802.65920288 -17720.77930906 -17633.62431593 -17557.23571682 -17452.40167091 -17399.68852073 -17375.94954496 -17307.88737345 -17255.72517548 -17197.0707438 -17164.6325927 -17126.14344098 -17065.23353405 -17015.69860102 -17014.62748204 -16952.26057017 -16931.50410923 -16877.22355746 -16896.55058729 -16850.14692956 -16830.3374694 -16815.5305069 -16796.75719138 -16767.65653772 -16786.27973141 -16730.8411949 -16726.66492692 -16712.45550115 -16698.65846017 -16704.0677061 -16665.32180052 -16669.28072212]
# Overlay of metrics index 3 (R_log_MSE) for every run: DSSVI in blue, SSVI in red.
plt.figure(figsize=(8, 5))
for j in range(len(dssvi_store)):
    for run, colour, name in ((dssvi_store[j], "b", "DSSVI"),
                              (ssvi_store[j], "r", "SSVI")):
        # Only the first run is labelled so each method appears once in the legend.
        extra = {"label": name} if j == 0 else {}
        plt.plot(run.time_list,
                 [record[3] for record in run.metrics_list],
                 color=colour, **extra)
plt.legend()
plt.xlabel("time (sec)", fontsize=16)
plt.ylabel("R_log_MSE", fontsize=16)
plt.title("R_log_MSE over time", fontsize=16)
plt.show()
# Ground truth for the K=3 experiment: three sparse components, no dense one.
# These assignments rebind the module-level names that show_results and
# show_true_combined read.
# true_pi_W[k] is the fraction of ones in column k of S_W and true_pi_H[k]
# the fraction of ones in row k of S_H (both built below).
true_pi_W = [0.8, 0.9, 0.8]
true_pi_H = [0.4, 0.3, 0.3]

# Binary support masks: S_W is 100 x 3 and S_H is 3 x 100.
S_W = np.vstack((np.hstack((np.ones(80), np.zeros(20))),
                 np.hstack((np.zeros(10), np.ones(90))),
                 np.hstack((np.ones(80), np.zeros(20))))).T
S_H = np.vstack((np.hstack((np.ones(40), np.zeros(60))),
                 np.hstack((np.zeros(30), np.ones(30), np.zeros(40))),
                 np.hstack((np.zeros(70), np.ones(30)))))
dssvi_func.show_true_S(S_W, S_H)

np.random.seed(10)
# Gamma(shape, scale=1) draws: the mean of each pool equals its shape.
cluster1 = np.random.gamma(1.6, 1, 600)  # component 1 values, mean 1.6
cluster2 = np.random.gamma(1.4, 1, 600)  # component 2 values, mean 1.4
cluster3 = np.random.gamma(1.2, 1, 600)  # component 3 values, mean 1.2
clusters = [cluster1, cluster2, cluster3]

# Populate W where S_W is non-zero: W_S_W[i, k] = clusters[k][i].
# NOTE: H below reads the same leading samples of every pool, so column k of
# W and row k of H share values wherever both masks are on.
W_S_W = np.where(S_W != 0,
                 np.column_stack([pool[:S_W.shape[0]] for pool in clusters]),
                 0.0)
# Similarly for H: H_S_H[k, j] = clusters[k][j] where S_H is non-zero.
H_S_H = np.where(S_H != 0,
                 np.vstack([pool[:S_H.shape[1]] for pool in clusters]),
                 0.0)

show_true_combined()

# Noiseless K=3 matrix and its heatmap.
X_true_2 = W_S_W.dot(H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true_2, aspect="auto", interpolation="none")
plt.colorbar()
# Output: <matplotlib.colorbar.Colorbar at 0x1d36c468408>
# Observed counts for the K=3 experiment: one Poisson draw per entry, with
# X_true_2 as the rate.
X_2 = poisson.rvs(X_true_2, size=X_true_2.shape)

# One row per run; the four values are passed to both models as a, b, c, d.
params = [[5, 5, 5, 5],
          [10, 10, 10, 10],
          [1, 1, 1, 1]]
ssvi_store_k3 = [0] * len(params)
dssvi_store_k3 = [0] * len(params)

# Run simulations using given parameters
for num, param in enumerate(params):
    print(f"Run: {num}")
    # Both models are built with identical settings.
    model_kwargs = dict(n_components=3, burn_in=20, random_state=10, verbose=False,
                        max_iter=50, cutoff=1e-2, X_true=X_true_2,
                        a=param[0], b=param[1], c=param[2], d=param[3])
    ssvi_store_k3[num] = ssvi_org.SSMF_BP_NMF(**model_kwargs)
    dssvi_store_k3[num] = dssvi_func.SSMF_BP_NMF(**model_kwargs)
    print("\t ssvi...")
    ssvi_store_k3[num].fit(X_2)
    print("\t dssvi...")
    dssvi_store_k3[num].fit(X_2)
# Output: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run grid of both metrics for the K=3 fits.
RRMSE_RlogMSE(dssvi_store_k3, ssvi_store_k3)

# One overlay figure per metric: (metrics index, axis label, optional y-limits).
# Only the R_log_MSE figure has its y-range clamped.
for column, y_label, y_limits in ((2, "RRMSE", None), (3, "R_log_MSE", (0, 2))):
    plt.figure(figsize=(8, 5))
    # Loop over the K=3 stores themselves rather than the K=4 ``dssvi_store``,
    # so the overlay stays correct if the two experiments differ in run count.
    for j in range(len(dssvi_store_k3)):
        for run, colour, name in ((dssvi_store_k3[j], "b", "DSSVI"),
                                  (ssvi_store_k3[j], "r", "SSVI")):
            # Only the first run is labelled so each method appears once in
            # the legend.
            extra = {"label": name} if j == 0 else {}
            plt.plot(run.time_list,
                     [record[column] for record in run.metrics_list],
                     color=colour, **extra)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.title(f"{y_label} over time", fontsize=16)
    plt.show()

# Diagnostics for the second K=3 DSSVI run.  The reference matrix must be
# X_true_2, the ground truth these models were fitted on; X_true belongs to
# the K=4 experiment.
show_results(dssvi_store_k3[1], X_true=X_true_2, dssvi=True)
# Output: [-23771.58110102 -15854.4261871 -15404.07453374 -15172.69342163 -14997.33489435 -14520.44828289 -14086.41641204 -13792.76545374 -13497.28902362 -13344.9658087 -13089.37094964 -12873.07086586 -12792.98396211 -12649.82059806 -12573.88166387 -12439.47826089 -12380.94729956 -12248.9270735 -12238.80563891 -12161.07147686 -12136.43758237 -12137.6556848 -12100.04261497 -12097.73036191 -12063.49130479 -11956.65864659 -11957.7633186 -11919.31630974 -11905.0657699 -11863.57269586 -11854.96197786 -11832.89331776 -11821.76379097 -11786.48525525 -11769.53900117 -11741.56803694 -11734.54823289 -11715.3495428 -11728.27944667 -11706.89050251 -11721.28277421 -11716.82284923 -11695.14395959 -11696.24122469 -11688.45668374 -11686.02697632 -11689.79972932 -11691.61811415 -11667.17639117 -11695.09719975]
# Diagnostics for the second K=3 SSVI run.  The reference matrix must be
# X_true_2, the ground truth these models were fitted on; X_true belongs to
# the K=4 experiment.
show_results(ssvi_store_k3[1], X_true=X_true_2, dssvi=False)
# Output: [-25577.97083781 -15734.91605102 -15554.14241692 -15459.89438228 -15441.88495387 -15408.97409785 -15439.18818764 -15407.86628549 -15370.64609978 -15349.15866243 -15274.43170167 -15231.84420108 -15173.07775379 -15049.64960182 -14854.92512164 -14679.83597013 -14423.5056339 -14236.67362875 -14070.45252172 -13807.90780734 -13747.06533853 -13556.9375978 -13370.86870523 -13307.33559183 -13255.66839166 -13100.4874542 -13045.15162861 -12994.84315017 -12909.54229129 -12882.03523254 -12914.10017658 -12828.87687149 -12815.22987413 -12811.25687692 -12755.02582848 -12740.38420695 -12744.23626453 -12777.35039427 -12754.30424883 -12759.22789377 -12731.61403917 -12661.63917789 -12697.31080144 -12733.8124612 -12697.57393298 -12730.5919241 -12738.76477444 -12704.31128799 -12692.14758253 -12692.28589407]