import dssvi_func
import ssvi_org
import time
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from scipy.stats import poisson
def plot_results(dssvi_obj, ssvi_obj):
    """Plot two convergence metrics over time for one DSSVI/SSVI pair.

    Two figures are shown one after the other. Each draws column 2 (labelled
    RRMSE) or column 3 (labelled Log RMSE) of every ``metrics_list`` entry
    against ``time_list`` for both fitted objects.

    Args:
        dssvi_obj: Fitted object exposing ``time_list`` and ``metrics_list``.
        ssvi_obj: Fitted object exposing the same two attributes.
    """
    # (metrics_list column, y-axis label, figure title, upper y limit)
    figure_specs = (
        (2, "RRMSE", "RRMSE across time", 1),
        (3, "Log RMSE", "Log RMSE across time", 2),
    )
    for column, y_label, heading, y_max in figure_specs:
        for fitted, tag in ((dssvi_obj, "DSSVI"), (ssvi_obj, "SSVI")):
            series = [entry[column] for entry in fitted.metrics_list]
            plt.plot(fitted.time_list, series, label=tag)
        plt.legend()
        plt.xlabel("time")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.ylim(0, y_max)
        plt.show()
    return
def _log_of_nonzero(matrix):
    """Return ``matrix`` as a DataFrame with the log taken of non-zero cells only."""
    # pd.DataFrame accepts arrays and DataFrames alike; zeros are masked to
    # NaN before the log and restored to 0 afterwards.
    frame = pd.DataFrame(matrix)
    return np.log(frame.replace(0, np.nan)).replace(np.nan, 0)


def show_results(obj, X_true, dssvi=True):
    """Show the diagnostic plots for one fitted model.

    In order: the K / log-likelihood summary, the expected pi plot, true vs.
    posterior-mean W and H, true vs. reconstructed X, and the same X
    comparison on a log scale (zero cells are left at 0).

    The true factors and pi values are read from the module-level names
    ``W_S_W``, ``H_S_H``, ``true_pi_W`` and ``true_pi_H``, so they must match
    the data set ``obj`` was fitted on when this is called.

    Args:
        obj: Fitted model. The DSSVI branch reads ``obj.S_W``, ``obj.S_H``
            and ``obj.good_k``; the SSVI branch reads ``obj.S`` and
            ``obj.good_k``.
        X_true: Noise-free data matrix (array or DataFrame).
        dssvi: Truthy to use the ``dssvi_func`` helpers, falsy to use the
            ``ssvi_org`` helpers.
    """
    if dssvi:
        helpers = dssvi_func
        # show # of K and log_ll plot
        helpers.K_and_log_ll(obj)
        # plot Exp. pi
        helpers.plot_pi(obj, true_pi_W=true_pi_W, true_pi_H=true_pi_H)
        # Calc. posterior means
        W_post_mean, H_post_mean, X_post_mean = helpers.get_post_means(obj)
        # draw matrices; both factors are masked by the learned supports
        helpers.draw_two_matrices(W_S_W, W_post_mean * obj.S_W[:, obj.good_k], 1, 1,
                                  'true W * S_W', 'E(W) * S_W')
        helpers.draw_two_matrices(H_S_H, H_post_mean * obj.S_H[obj.good_k], 1, 1,
                                  'true H * S_H', 'E(H) * S_H')
    else:
        helpers = ssvi_org
        # show # of K and log_ll plot
        helpers.K_and_log_ll(obj)
        # plot Exp. pi (this variant only takes the H-side pi)
        helpers.plot_pi(obj, true_pi_H=true_pi_H)
        # Calc. posterior means
        W_post_mean, H_post_mean, X_post_mean = helpers.get_post_means(obj)
        # draw matrices; only H is masked here, W is drawn as returned
        helpers.draw_two_matrices(W_S_W, W_post_mean, 1, 1, 'true W * S_W', 'E(W)')
        helpers.draw_two_matrices(H_S_H, H_post_mean * obj.S[obj.good_k], 1, 1,
                                  'true H * S_H', 'E(H) * S_H')

    helpers.draw_two_matrices(X_true, X_post_mean, 1, 1, first='True X', second='Reconstructed X')
    # draw log
    helpers.draw_two_matrices(_log_of_nonzero(X_true), _log_of_nonzero(X_post_mean), 1, 1,
                              'log(X true)', 'log(X_reconst) (0s not taken log)')
    return
def RRMSE_RlogMSE(dssvi_store, ssvi_store):
    """Draw a grid of convergence curves, one row per run.

    The left column plots ``metrics_list`` column 2 (titled RRMSE, y range
    0-1) and the right column plots column 3 (titled RlogMSE, y range 0-4),
    each against ``time_list`` for the DSSVI and SSVI fit of that run.

    Args:
        dssvi_store: Sequence of fitted DSSVI objects, one per run.
        ssvi_store: Sequence of fitted SSVI objects, indexed like
            ``dssvi_store``.
    """
    # (title prefix, y-axis range) for grid column 0 and 1
    panels = (("RRMSE", [0, 1]), ("RlogMSE", [0, 4]))
    # squeeze=False keeps ``ax`` two-dimensional even for a single run, so
    # ax[i, j] indexing below is always valid.
    fig, ax = plt.subplots(len(dssvi_store), 2, figsize=(14, 14), squeeze=False)
    for i in range(len(dssvi_store)):
        for j, (metric_name, y_range) in enumerate(panels):
            axis = ax[i, j]
            for fitted, tag in ((dssvi_store[i], "DSSVI"), (ssvi_store[i], "SSVI")):
                axis.plot(fitted.time_list,
                          [entry[j + 2] for entry in fitted.metrics_list],
                          label=tag)
            axis.set_title(f"{metric_name}, run={i+1}", fontsize=14)
            axis.set_ylim(y_range)
            axis.set_xlabel("time (sec)", fontsize=14)
            axis.legend()
    fig.tight_layout(pad=3.0)
    plt.show()
    return
# Ground truth for the K = 4 experiment. Each pi value equals the fraction of
# ones in the matching column of S_W / row of S_H built below.
true_pi_W = [0.8, 0.9, 0.8, 1]
true_pi_H = [0.4, 0.3, 0.3, 1]
# S_W is 100 x 4: every stacked row below becomes one column after the
# transpose. The fourth component is fully dense.
S_W = np.vstack((
    np.hstack((np.ones(80), np.zeros(20))),
    np.hstack((np.zeros(10), np.ones(90))),
    np.hstack((np.ones(80), np.zeros(20))),
    np.ones(100),
)).T
# S_H is 4 x 100: three disjoint-ish blocks of active columns plus one dense row.
S_H = np.vstack((
    np.hstack((np.ones(40), np.zeros(60))),
    np.hstack((np.zeros(30), np.ones(30), np.zeros(40))),
    np.hstack((np.zeros(70), np.ones(30))),
    np.ones(100),
))
# Hand the true K = 4 support matrices to the project helper
# (presumably it plots both; see dssvi_func.show_true_S).
dssvi_func.show_true_S(S_W, S_H)
np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape, so the three sparse
# clusters average 1.6, 1.4 and 1.2 and the dense component averages 1.0.
# The draw order is kept fixed so the seeded values are reproducible.
cluster1 = np.random.gamma(1.6, 1, 600)
cluster2 = np.random.gamma(1.4, 1, 600)
cluster3 = np.random.gamma(1.2, 1, 600)
dense = np.random.gamma(1, 1, 10000)
clusters = [cluster1, cluster2, cluster3, dense]

# populate W where entries for S_W is not 0: W_S_W[i, j] = clusters[j][i]
W_values = np.vstack([c[:S_W.shape[0]] for c in clusters[:S_W.shape[1]]]).T
W_active = S_W != 0
W_S_W = np.zeros(S_W.shape)
W_S_W[W_active] = W_values[W_active]

# similarly for H_S_H: H_S_H[i, j] = clusters[i][j]
# (component k of H reuses the same draws as component k of W)
H_values = np.vstack([c[:S_H.shape[1]] for c in clusters[:S_H.shape[0]]])
H_active = S_H != 0
H_S_H = np.zeros(S_H.shape)
H_S_H[H_active] = H_values[H_active]
def show_true_combined():
    """Show heat maps of the module-level ``W_S_W`` and ``H_S_H`` matrices.

    Each matrix is drawn in its own figure with a colorbar. The globals are
    read at call time, so the current data set's factors are shown.
    """
    for matrix, fig_size in ((W_S_W, (4, 5)), (H_S_H, (6, 3))):
        plt.figure(figsize=fig_size)
        plt.imshow(matrix, aspect="auto", interpolation="none")
        plt.colorbar()
show_true_combined()
# Noise-free K = 4 data matrix: product of the masked true factors.
X_true_k4 = np.dot(W_S_W, H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true_k4, aspect="auto", interpolation="none")
plt.colorbar()
# Output: <matplotlib.colorbar.Colorbar at 0x25f729b2888>
# From here on X_true_k4 is a DataFrame rather than an ndarray.
X_true_k4 = pd.DataFrame.from_records(X_true_k4)
# Log of the non-zero cells only: zeros are masked to NaN before the log and
# put back as 0 afterwards.
X_true_k4_nonzero = X_true_k4.replace(0, np.nan)
X_true_log = np.log(X_true_k4_nonzero).replace(np.nan, 0)
dssvi_func.draw_two_matrices(
    X_true_k4, X_true_log, 1, 1,
    'X_true', "log(X_true) (0's not taken log)",
)
# `X_true` was never defined in this file; the noise-free matrix of this
# experiment is X_true_k4. A plain ndarray is used, as in the K = 3
# experiment, so both runs pass the same type to the models.
X_true = np.asarray(X_true_k4)
# Observed counts: one Poisson draw per cell with the true matrix as rate.
X_1 = poisson.rvs(X_true, size=X_true.shape)
params = [4, 5, 6]  # number of k's
ssvi_store = []
dssvi_store = []
# Run simulations using given parameters
for num, param in enumerate(params):
    print("Run: " + str(num))
    ssvi_model = ssvi_org.SSMF_BP_NMF(n_components=param, burn_in=30, random_state=10, verbose=False,
                                      max_iter=50, cutoff=1e-2, X_true=X_true)
    dssvi_model = dssvi_func.SSMF_BP_NMF(n_components=param, burn_in=30, random_state=10, verbose=False,
                                         max_iter=50, cutoff=1e-2, X_true=X_true)
    ssvi_store.append(ssvi_model)
    dssvi_store.append(dssvi_model)
    print("\t ssvi...")
    ssvi_model.fit(X_1)
    print("\t dssvi...")
    dssvi_model.fit(X_1)
# Output: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run metric panels for the K = 4 data, then the diagnostic plots for the
# DSSVI fit of the first run (n_components = params[0]).
RRMSE_RlogMSE(dssvi_store, ssvi_store)
show_results(dssvi_store[0], X_true=X_true_k4, dssvi=True)
# Output: [-25522.57132897 -19726.63423165 -18876.91473156 -18393.38412282 -18079.18417316 -17801.82589949 -17591.00948364 -17440.5780344 -17298.72405712 -17217.61438159 -17116.90760513 -17050.29816299 -17000.62905421 -16885.89438975 -16822.62962466 -16757.34718092 -16705.79034669 -16674.98366874 -16648.11949025 -16610.45011534 -16520.98051036 -16563.14842601 -16499.68270953 -16443.89957558 -16429.95910382 -16427.52968467 -16410.06515043 -16327.17850375 -16309.26581189 -16288.80794426 -16293.26713613 -16253.93612885 -16266.64523313 -16233.63771589 -16247.67297462 -16191.52194656 -16256.24580881 -16185.33206991 -16146.89600118 -16147.46851692 -16132.13151 -16169.40898864 -16189.72965814 -16141.42123752 -16135.35454216 -16146.11814686 -16117.24430957 -16118.02024608 -16096.5134814 -16081.16494167]
# Same diagnostics for the SSVI fit of the first run (n_components = params[0]).
show_results(ssvi_store[0], X_true=X_true_k4, dssvi=False)
# Output: [-27110.01061978 -20148.47159768 -19742.01144767 -19540.4491911 -19419.47224964 -19278.98893305 -19165.06100956 -18969.04346315 -18778.84204376 -18616.22283575 -18459.4737587 -18378.64081091 -18264.91406975 -18193.8133312 -18098.87725899 -18049.59923067 -17982.00463566 -17907.47660735 -17876.34283888 -17828.07274381 -17762.57000736 -17736.90337419 -17675.67953992 -17626.77648024 -17629.22083809 -17557.15128726 -17512.25675009 -17532.2171025 -17462.53562723 -17436.46483677 -17373.57214146 -17355.87924769 -17352.80271961 -17292.34475186 -17279.34517559 -17224.48953733 -17236.84795684 -17200.93843866 -17182.31771022 -17206.51648786 -17179.39293832 -17116.57009669 -17107.19576258 -17064.07119704 -17081.99543333 -17109.42808511 -17036.69137153 -17036.10420657 -16994.32180705 -16988.93140715]
# Overlay every K = 4 run on one figure per metric (metrics_list column 2,
# then column 3). DSSVI curves are blue and SSVI curves red; only the first
# run of each method carries a label, so the legend has one entry per method.
for metric_col, metric_name in ((2, "RRMSE"), (3, "R_log_MSE")):
    plt.figure(figsize=(8, 5))
    for run in range(len(dssvi_store)):
        for fitted, colour, tag in ((dssvi_store[run], "b", "DSSVI"),
                                    (ssvi_store[run], "r", "SSVI")):
            curve = [entry[metric_col] for entry in fitted.metrics_list]
            if run == 0:
                plt.plot(fitted.time_list, curve, color=colour, label=tag)
            else:
                plt.plot(fitted.time_list, curve, color=colour)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(metric_name, fontsize=16)
    plt.title(f"{metric_name} over time", fontsize=16)
    plt.show()
# Ground truth for the K = 3 experiment (no dense component). Each pi value
# equals the fraction of ones in the matching column of S_W / row of S_H.
true_pi_W = [0.8, 0.9, 0.8]
true_pi_H = [0.4, 0.3, 0.3]

# Fill one row per component, then transpose so S_W ends up 100 x 3.
S_W = np.zeros((3, 100))
S_W[0, :80] = 1
S_W[1, 10:] = 1
S_W[2, :80] = 1
S_W = S_W.T

# S_H is 3 x 100 with one block of active columns per component.
S_H = np.zeros((3, 100))
S_H[0, :40] = 1
S_H[1, 30:60] = 1
S_H[2, 70:] = 1
# Hand the true K = 3 support matrices to the project helper
# (presumably it plots both; see dssvi_func.show_true_S).
dssvi_func.show_true_S(S_W, S_H)
np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape, so the three clusters
# average 1.6, 1.4 and 1.2. The draw order is kept fixed so the seeded values
# are reproducible.
cluster1 = np.random.gamma(1.6, 1, 600)
cluster2 = np.random.gamma(1.4, 1, 600)
cluster3 = np.random.gamma(1.2, 1, 600)
clusters = [cluster1, cluster2, cluster3]

# populate W where entries for S_W is not 0: W_S_W[i, j] = clusters[j][i]
W_values = np.vstack([c[:S_W.shape[0]] for c in clusters[:S_W.shape[1]]]).T
W_active = S_W != 0
W_S_W = np.zeros(S_W.shape)
W_S_W[W_active] = W_values[W_active]

# similarly for H_S_H: H_S_H[i, j] = clusters[i][j]
# (component k of H reuses the same draws as component k of W)
H_values = np.vstack([c[:S_H.shape[1]] for c in clusters[:S_H.shape[0]]])
H_active = S_H != 0
H_S_H = np.zeros(S_H.shape)
H_S_H[H_active] = H_values[H_active]
show_true_combined()
# Noise-free K = 3 data matrix: product of the masked true factors.
X_true_2 = np.dot(W_S_W, H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true_2, aspect="auto", interpolation="none")
plt.colorbar()
# Output: <matplotlib.colorbar.Colorbar at 0x25f72151448>
# Observed counts for the K = 3 experiment: one Poisson draw per cell with
# X_true_2 as rate. The sample shape is taken from X_true_2 itself rather
# than from a matrix belonging to the other experiment.
X_2 = poisson.rvs(X_true_2, size=X_true_2.shape)
params = [3, 4, 5]  # number of k's
ssvi_store_k3 = []
dssvi_store_k3 = []
# Run simulations using given parameters
for num, param in enumerate(params):
    print("Run: " + str(num))
    ssvi_model = ssvi_org.SSMF_BP_NMF(n_components=param, burn_in=30, random_state=10, verbose=False,
                                      max_iter=50, cutoff=1e-2, X_true=X_true_2)
    dssvi_model = dssvi_func.SSMF_BP_NMF(n_components=param, burn_in=30, random_state=10, verbose=False,
                                         max_iter=50, cutoff=1e-2, X_true=X_true_2)
    ssvi_store_k3.append(ssvi_model)
    dssvi_store_k3.append(dssvi_model)
    print("\t ssvi...")
    ssvi_model.fit(X_2)
    print("\t dssvi...")
    dssvi_model.fit(X_2)
# Output: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run metric panels for the K = 3 data.
RRMSE_RlogMSE(dssvi_store_k3, ssvi_store_k3)

# Overlay every K = 3 run on one figure per metric (metrics_list column 2,
# then column 3). The run count comes from dssvi_store_k3 itself, not from the
# K = 4 store. DSSVI curves are blue and SSVI curves red; only the first run
# of each method carries a label, so the legend has one entry per method.
# Only the R_log_MSE figure gets a fixed y range.
for metric_col, metric_name, y_limits in ((2, "RRMSE", None), (3, "R_log_MSE", (0, 2))):
    plt.figure(figsize=(8, 5))
    for run in range(len(dssvi_store_k3)):
        for fitted, colour, tag in ((dssvi_store_k3[run], "b", "DSSVI"),
                                    (ssvi_store_k3[run], "r", "SSVI")):
            curve = [entry[metric_col] for entry in fitted.metrics_list]
            if run == 0:
                plt.plot(fitted.time_list, curve, color=colour, label=tag)
            else:
                plt.plot(fitted.time_list, curve, color=colour)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(metric_name, fontsize=16)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.title(f"{metric_name} over time", fontsize=16)
    plt.show()

# Diagnostic plots for the DSSVI fit of the last run (n_components = params[2]).
show_results(dssvi_store_k3[2], X_true=X_true_2, dssvi=True)
# Output: [-20281.15581056 -14615.00857998 -13094.63055279 -12576.9326782 -12355.23436137 -12225.51841504 -12102.90225163 -12007.61870286 -11946.58357307 -11851.16935219 -11826.41291637 -11746.98468501 -11711.22563204 -11620.7055886 -11594.94892436 -11538.3117039 -11517.14505252 -11504.51195945 -11432.57280567 -11402.40823238 -11366.3271555 -11305.50572406 -11291.89836053 -11234.74158855 -11246.78939902 -11208.25562965 -11206.8534479 -11172.29743815 -11160.74255736 -11126.92242927 -11154.73723704 -11057.50769446 -11096.40712796 -11059.37967177 -11063.19350937 -10988.53536287 -10992.9877148 -10988.83774082 -10966.36468268 -10958.35564381 -10970.04979593 -11016.67557781 -10986.73892563 -10942.29395087 -10918.04923229 -10918.3697848 -10909.7894691 -10927.54040753 -10914.67600467 -10879.77500753]
# Same diagnostics for the SSVI fit of the last run (n_components = params[2]).
show_results(ssvi_store_k3[2], X_true=X_true_2, dssvi=False)
# Output: [ -inf -15266.7425035 -14802.42047887 -14230.33430126 -13756.1626744 -13424.9654907 -13149.52647593 -12966.16419434 -12774.92493745 -12635.76633309 -12561.97243876 -12412.88391427 -12352.19843257 -12253.79673488 -12169.15501013 -12151.63431173 -12037.29110315 -12010.24688554 -11946.89679843 -11874.88848861 -11834.92349096 -11825.64575133 -11765.66716437 -11747.4105137 -11736.83796434 -11657.14330268 -11665.45418286 -11649.6071363 -11581.02760795 -11596.67011882 -11585.35565126 -11574.68664231 -11502.27993684 -11499.75035619 -11470.11241266 -11451.4864661 -11471.8146118 -11430.94227415 -11386.56406472 -11425.55070582 -11384.73086875 -11348.33765309 -11377.89860768 -11345.94465949 -11332.62962519 -11293.76691984 -11269.97387427 -11280.73725884 -11206.86205667 -11224.06773017]