import dssvi_func
import ssvi_org
import time
import numpy as np
import matplotlib.pyplot as plt
import random
import pandas as pd
from scipy.stats import poisson
# Spacing between 1.0 and the next representable float; added to matrices
# before np.log(...) further down so that zero entries do not become -inf.
EPS = np.spacing(1)
def plot_results(dssvi_obj, ssvi_obj):
    """Plot two error metrics against elapsed time for a DSSVI/SSVI pair.

    Two figures are shown one after the other.  In each, the x-values are
    the object's ``time_list`` and the y-values are one fixed column taken
    from every row of its ``metrics_list``: column 2 for the "RRMSE" figure
    and column 3 for the "Log RMSE" figure.

    Parameters
    ----------
    dssvi_obj, ssvi_obj
        Fitted objects exposing ``time_list`` and ``metrics_list``.
    """
    # (metrics column, y-axis label, figure title, upper y-limit)
    figure_specs = (
        (2, "RRMSE", "RRMSE across time", 1),
        (3, "Log RMSE", "Log RMSE across time", 2),
    )
    for column, y_label, heading, y_max in figure_specs:
        for fitted, legend_name in ((dssvi_obj, "DSSVI"), (ssvi_obj, "SSVI")):
            series = [row[column] for row in fitted.metrics_list]
            plt.plot(fitted.time_list, series, label=legend_name)
        plt.legend()
        plt.xlabel("time")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.ylim(0, y_max)
        plt.show()
    return
def _log_of_nonzero(matrix):
    """Return ``log(matrix)`` as a DataFrame, leaving zero entries at 0."""
    frame = pd.DataFrame.from_records(matrix)
    # Zeros are masked as NaN so np.log never sees them, then set back to 0.
    return np.log(frame.replace(0, np.nan)).replace(np.nan, 0)


def show_results(obj, X_true, dssvi=True):
    """Show the diagnostic plots for one fitted model.

    The same sequence of helper calls is made on either ``dssvi_func`` or
    ``ssvi_org``: ``K_and_log_ll``, ``plot_pi``, ``get_post_means`` and four
    ``draw_two_matrices`` comparisons (W, H, X and a log-scale X).

    The ground-truth values ``true_pi_W``, ``true_pi_H``, ``W_S_W`` and
    ``H_S_H`` are read from module scope at call time, so they must describe
    the same experiment as ``obj`` and ``X_true``.

    Parameters
    ----------
    obj
        Fitted model.  The DSSVI path reads ``obj.S_W``, ``obj.S_H`` and
        ``obj.good_k``; the SSVI path reads ``obj.S`` and ``obj.good_k``.
    X_true
        Noiseless matrix the reconstruction is compared against.
    dssvi : bool, default True
        Truthy selects the ``dssvi_func`` helpers, falsy the ``ssvi_org`` ones.
    """
    backend = dssvi_func if dssvi else ssvi_org

    # Per the original notes: number of components and log-likelihood trace.
    backend.K_and_log_ll(obj)

    # Only the DSSVI helper is given a reference value for W.
    if dssvi:
        backend.plot_pi(obj, true_pi_W=true_pi_W, true_pi_H=true_pi_H)
    else:
        backend.plot_pi(obj, true_pi_H=true_pi_H)

    W_post_mean, H_post_mean, X_post_mean = backend.get_post_means(obj)

    if dssvi:
        # DSSVI masks E(W) with its own S_W, restricted to the kept components.
        W_shown = W_post_mean * obj.S_W[:, obj.good_k]
        W_label = 'E(W) * S_W'
        H_mask = obj.S_H
    else:
        # The SSVI object exposes a single mask ``S`` (applied to H only).
        W_shown = W_post_mean
        W_label = 'E(W)'
        H_mask = obj.S

    backend.draw_two_matrices(W_S_W, W_shown, 1, 1, 'true W * S_W', W_label)
    backend.draw_two_matrices(H_S_H, H_post_mean * H_mask[obj.good_k], 1, 1,
                              'true H * S_H', 'E(H) * S_H')
    backend.draw_two_matrices(X_true, X_post_mean, 1, 1,
                              first='True X', second='Reconstructed X')

    # Log-scale comparison; zeros are left at 0 rather than becoming -inf.
    backend.draw_two_matrices(_log_of_nonzero(X_true),
                              _log_of_nonzero(X_post_mean), 1, 1,
                              'log(X true)', 'log(X_reconst) (0s not taken log)')
    return
def RRMSE_RlogMSE(dssvi_store, ssvi_store):
    """Draw a grid of metric-vs-time curves, one row per run.

    Row ``i`` compares ``dssvi_store[i]`` with ``ssvi_store[i]``.  The left
    panel plots column 2 of each object's ``metrics_list`` (titled "RRMSE"),
    the right panel column 3 (titled "RlogMSE"), both against ``time_list``.

    Parameters
    ----------
    dssvi_store, ssvi_store
        Equal-length sequences of fitted objects exposing ``time_list`` and
        ``metrics_list``.
    """
    n_runs = len(dssvi_store)
    # squeeze=False keeps ``ax`` two-dimensional even for a single run; with
    # the default squeezing a one-row grid comes back 1-D and ax[i, j] fails.
    fig, ax = plt.subplots(n_runs, 2, figsize=(14, 14), squeeze=False)

    # (panel title prefix, y-axis range) for the left and right column.
    panels = (("RRMSE", [0, 1]), ("RlogMSE", [0, 4]))

    for i in range(n_runs):
        for j, (metric_name, y_range) in enumerate(panels):
            column = j + 2  # metrics_list columns 2 and 3
            panel = ax[i, j]
            panel.plot(dssvi_store[i].time_list,
                       [row[column] for row in dssvi_store[i].metrics_list],
                       label="DSSVI")
            panel.plot(ssvi_store[i].time_list,
                       [row[column] for row in ssvi_store[i].metrics_list],
                       label="SSVI")
            panel.set_title(f"{metric_name}, run={i+1}", fontsize=14)
            panel.set_ylim(y_range)
            panel.set_xlabel("time (sec)", fontsize=14)
            panel.legend()

    fig.tight_layout(pad=3.0)
    plt.show()
    return
# Fraction of ones in each column of S_W / each row of S_H defined below
# (the fourth component is fully dense).
true_pi_W = [0.3, 0.2, 0.3, 1]
true_pi_H = [0.4, 0.5, 0.45, 1]

# Binary support patterns: S_W is 100 x 4 (before the repeat), S_H is 4 x 100.
S_W = np.vstack((
    np.hstack((np.ones(30), np.zeros(70))),
    np.hstack((np.zeros(20), np.ones(20), np.zeros(60))),
    np.hstack((np.zeros(50), np.ones(30), np.zeros(20))),
    np.ones(100),
)).T
S_H = np.vstack((
    np.hstack((np.ones(40), np.zeros(60))),
    np.hstack((np.zeros(20), np.ones(50), np.zeros(30))),
    np.hstack((np.zeros(55), np.ones(45))),
    np.ones(100),
))
# Repeat every row of S_W six times, giving 600 rows.
S_W = S_W.repeat(6, axis=0)
dssvi_func.show_true_S(S_W, S_H)

np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape, so the four components
# have means 1.6, 1.4, 1.2 and 1.  The draw order and sizes are kept exactly
# as they were so the global random stream (and every later sample) is
# unchanged; only a prefix of each vector is actually used below.
cluster1 = np.random.gamma(1.6, 1, 600)
cluster2 = np.random.gamma(1.4, 1, 600)
cluster3 = np.random.gamma(1.2, 1, 600)
dense = np.random.gamma(1, 1, 10000)
clusters = [cluster1, cluster2, cluster3, dense]

# W_S_W[i, j] = clusters[j][i] wherever S_W[i, j] is non-zero, 0 elsewhere.
n_rows_W = S_W.shape[0]
W_S_W = np.where(S_W != 0,
                 np.vstack([values[:n_rows_W] for values in clusters]).T,
                 0.0)

# H_S_H[i, j] = clusters[i][j] wherever S_H[i, j] is non-zero, 0 elsewhere.
n_cols_H = S_H.shape[1]
H_S_H = np.where(S_H != 0,
                 np.vstack([values[:n_cols_H] for values in clusters]),
                 0.0)
def show_true_combined():
    """Display the ground-truth W_S_W and H_S_H as two colour-mapped images.

    Both matrices are looked up in module scope when the function is called,
    so it shows whichever pair was assigned most recently.
    """
    for size, matrix in (((4, 5), W_S_W), ((6, 3), H_S_H)):
        plt.figure(figsize=size)
        plt.imshow(matrix, aspect="auto", interpolation="none")
        plt.colorbar()
# Show the ground-truth factors, then build and show their noiseless product.
show_true_combined()
X_true_k4 = np.dot(W_S_W, H_S_H)
plt.figure(figsize=(5, 5))
plt.imshow(X_true_k4, interpolation="none", aspect="auto")
plt.colorbar()
# Out: <matplotlib.colorbar.Colorbar at 0x23006e50a08>
# Log of the noiseless matrix; EPS keeps zero entries finite.
X_true_log = np.log(X_true_k4 + EPS)
dssvi_func.draw_two_matrices(X_true_k4, X_true_log, 1, 1)

# Observed counts: one Poisson draw per entry, with X_true_k4 as the rates.
X_1_k4 = poisson.rvs(X_true_k4, size=X_true_k4.shape)
# Bare expression left over from the notebook cell (displays the array there).
X_1_k4
# Out: array([[18, 7, 2, ..., 0, 0, 0], [ 9, 5, 5, ..., 0, 0, 0], [ 6, 7, 2, ..., 0, 0, 0], ..., [ 0, 0, 0, ..., 0, 0, 0], [ 0, 0, 0, ..., 0, 0, 0], [ 0, 0, 0, ..., 0, 0, 0]])
# Numbers of components to try.
params = [4, 5, 6]
ssvi_store = [0] * len(params)
dssvi_store = [0] * len(params)

# For every candidate, build both models with identical settings, then fit
# SSVI first and DSSVI second on the same observed matrix.
for num, param in enumerate(params):
    print("Run: " + str(num))
    settings = dict(n_components=param, burn_in=30, random_state=10,
                    verbose=False, max_iter=50, cutoff=1e-2, X_true=X_true_k4)
    ssvi_store[num] = ssvi_org.SSMF_BP_NMF(**settings)
    dssvi_store[num] = dssvi_func.SSMF_BP_NMF(**settings)
    print("\t ssvi...")
    ssvi_store[num].fit(X_1_k4)
    print("\t dssvi...")
    dssvi_store[num].fit(X_1_k4)
# Out: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run grid of RRMSE / RlogMSE curves for the three fits above.
RRMSE_RlogMSE(dssvi_store, ssvi_store)
# Detailed diagnostics for the DSSVI fit with n_components=4 (params[0]).
show_results(dssvi_store[0], X_true=X_true_k4, dssvi=True)
# Out: [ -inf -92584.60733129 -88301.82138469 -86598.1515705 -85492.42214018 -84592.42860797 -83774.54636649 -82877.68221064 -81936.91651757 -81211.69560256 -80578.63858483 -79839.24038689 -79344.27027801 -78756.62365166 -78308.08922013 -77815.80560249 -77370.36199023 -76987.84938648 -76498.39731705 -76099.52705807 -75673.74022625 -75287.96778998 -75177.13531224 -74912.57827451 -74659.15443839 -74427.21266601 -74317.49994059 -74089.72024779 -73940.55236573 -73840.75216724 -73649.93130908 -73453.39992792 -73413.69291633 -73273.26752471 -73221.60306787 -73050.09183232 -73020.6466285 -72845.97114001 -72882.87866653 -72836.7162221 -72737.61755564 -72615.38462916 -72549.00062353 -72404.54224874 -72402.5275705 -72415.42124403 -72351.9853129 -72165.32394687 -72189.09098025 -72170.5856766 ]
# Same diagnostics for the SSVI fit with n_components=4 (params[0]).
show_results(ssvi_store[0], X_true=X_true_k4, dssvi=False)
# Out: [-132019.63010399 -96119.88195495 -95175.42129265 -93464.99292838 -92315.50016741 -91000.74272602 -89808.00371411 -88804.78109632 -88128.50381547 -87237.70118624 -86658.34796922 -86124.37442101 -85516.3282071 -84876.0923982 -84520.14109551 -84013.30131254 -83557.06256209 -83181.97077467 -82773.9190232 -82284.48929329 -81957.71678067 -81553.25465493 -81180.3503292 -80659.6417874 -80281.03515183 -80019.63796929 -79748.92018435 -79541.99310103 -79313.9926883 -79098.47612167 -78912.96356889 -78653.19737813 -78456.05295314 -78118.29258412 -77854.12341108 -77727.95717188 -77491.61662367 -77249.23836002 -77078.74876726 -76919.90436868 -76636.92791472 -76403.17340725 -76235.50570695 -75977.26733837 -75787.60749926 -75596.37565042 -75297.2377661 -75161.5018069 -75008.55614366 -74875.49768602]
# One figure per metric with every run overlaid: DSSVI in blue, SSVI in red.
# Only the first run carries legend labels, so each method is listed once.
# Each entry is (metrics_list column, y-axis label, figure title).
for column, metric_label, heading in (
    (2, "RRMSE", "RRMSE over time"),
    (3, "R_log_MSE", "R_log_MSE over time"),
):
    plt.figure(figsize=(8, 5))
    for j in range(len(dssvi_store)):
        dssvi_style = {"color": "b"}
        ssvi_style = {"color": "r"}
        if j == 0:
            dssvi_style["label"] = "DSSVI"
            ssvi_style["label"] = "SSVI"
        plt.plot(dssvi_store[j].time_list,
                 [row[column] for row in dssvi_store[j].metrics_list],
                 **dssvi_style)
        plt.plot(ssvi_store[j].time_list,
                 [row[column] for row in ssvi_store[j].metrics_list],
                 **ssvi_style)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(metric_label, fontsize=16)
    plt.title(heading, fontsize=16)
    plt.show()
# Second experiment: three overlapping components, no dense one.
# Fraction of ones in each column of S_W / each row of S_H below.
true_pi_W = [0.3, 0.4, 0.5]
true_pi_H = [0.4, 0.5, 0.45]

# Binary support patterns: S_W is 100 x 3, S_H is 3 x 100.
S_W = np.vstack((
    np.hstack((np.ones(30), np.zeros(70))),
    np.hstack((np.zeros(20), np.ones(40), np.zeros(40))),
    np.hstack((np.zeros(50), np.ones(50))),
)).T
S_H = np.vstack((
    np.hstack((np.ones(40), np.zeros(60))),
    np.hstack((np.zeros(20), np.ones(50), np.zeros(30))),
    np.hstack((np.zeros(55), np.ones(45))),
))
dssvi_func.show_true_S(S_W, S_H)

np.random.seed(10)
# Gamma(shape, scale=1) has mean equal to its shape: 2, 1.5 and 1 here.
cluster1 = np.random.gamma(2, 1, 50)
cluster2 = np.random.gamma(1.5, 1, 50)
cluster3 = np.random.gamma(1, 1, 50)

# Fill each component's active rows of W with the leading cluster values.
W_S_W = np.zeros(S_W.shape)
for col, rows, values in (
    (0, slice(None, 30), cluster1[:30]),
    (1, slice(20, 60), cluster2[:40]),
    (2, slice(50, None), cluster3[:50]),
):
    W_S_W[rows, col] = S_W[rows, col] * values

# Same for the active columns of H; the cluster vectors are shared with W.
H_S_H = np.zeros(S_H.shape)
for row, cols, values in (
    (0, slice(None, 40), cluster1[:40]),
    (1, slice(20, 70), cluster2[:50]),
    (2, slice(55, None), cluster3[:45]),
):
    H_S_H[row, cols] = S_H[row, cols] * values

show_true_combined()

# Noiseless product and one Poisson draw per entry as the observed matrix.
X_true_k3 = np.dot(W_S_W, H_S_H)
X_1_k3 = poisson.rvs(X_true_k3, size=X_true_k3.shape)
dssvi_func.draw_two_matrices(X_true_k3, X_1_k3, 1, 1)
# Numbers of components to try for the three-component data.
params = [3, 4, 5]
ssvi_store_k3 = [0] * len(params)
dssvi_store_k3 = [0] * len(params)

# For every candidate, build both models with identical settings, then fit
# SSVI first and DSSVI second on the same observed matrix.
for num, param in enumerate(params):
    print("Run: " + str(num))
    settings = dict(n_components=param, burn_in=30, random_state=10,
                    verbose=False, max_iter=50, cutoff=1e-2, X_true=X_true_k3)
    ssvi_store_k3[num] = ssvi_org.SSMF_BP_NMF(**settings)
    dssvi_store_k3[num] = dssvi_func.SSMF_BP_NMF(**settings)
    print("\t ssvi...")
    ssvi_store_k3[num].fit(X_1_k3)
    print("\t dssvi...")
    dssvi_store_k3[num].fit(X_1_k3)
# Out: Run: 0 ssvi... dssvi... Run: 1 ssvi... dssvi... Run: 2 ssvi... dssvi...
# Per-run grid of RRMSE / RlogMSE curves for the three-component experiment.
RRMSE_RlogMSE(dssvi_store_k3, ssvi_store_k3)

# One figure per metric with every run overlaid: DSSVI in blue, SSVI in red.
# Only the first run carries legend labels, so each method is listed once.
# The run count comes from dssvi_store_k3 itself; the earlier version looped
# over len(dssvi_store), the list from the four-component experiment, which
# only worked because both lists happen to hold three runs.
# Each entry is (metrics_list column, y-axis label, title, y-limits or None).
for column, metric_label, heading, y_limits in (
    (2, "RRMSE", "RRMSE over time", None),
    (3, "R_log_MSE", "R_log_MSE over time", (0, 2)),
):
    plt.figure(figsize=(8, 5))
    for j in range(len(dssvi_store_k3)):
        dssvi_style = {"color": "b"}
        ssvi_style = {"color": "r"}
        if j == 0:
            dssvi_style["label"] = "DSSVI"
            ssvi_style["label"] = "SSVI"
        plt.plot(dssvi_store_k3[j].time_list,
                 [row[column] for row in dssvi_store_k3[j].metrics_list],
                 **dssvi_style)
        plt.plot(ssvi_store_k3[j].time_list,
                 [row[column] for row in ssvi_store_k3[j].metrics_list],
                 **ssvi_style)
    plt.legend()
    plt.xlabel("time (sec)", fontsize=16)
    plt.ylabel(metric_label, fontsize=16)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.title(heading, fontsize=16)
    plt.show()
# Bare expression left over from the notebook cell: shows the `good_k`
# attribute of the DSSVI fit with n_components=4 (params[1]).
dssvi_store_k3[1].good_k
# Out: array([0, 1, 2, 3])
# Detailed diagnostics for the DSSVI fit with n_components=4 (params[1]).
show_results(dssvi_store_k3[1], X_true=X_true_k3, dssvi=True)
# Out: [ -inf -13604.60485974 -12790.85636547 -12032.61171757 -11437.80139388 -10882.10264773 -10612.74434041 -10285.3805076 -10137.42431117 -9889.1253092 -9504.87351073 -9222.12921773 -9020.9332091 -8787.63096083 -8617.61528936 -8359.62356095 -8182.77031951 -7991.77773177 -7792.6405216 -7714.22968637 -7666.7501047 -7652.29030906 -7615.50987914 -7588.55507461 -7578.5283936 -7555.15464635 -7554.03404646 -7555.33694682 -7516.36653106 -7510.61280757 -7512.85406598 -7512.78086638 -7507.98170867 -7481.80656165 -7486.33266383 -7471.36643723 -7490.29125839 -7448.05978184 -7438.60795542 -7471.70849009 -7406.66954697 -7425.09855646 -7456.31646419 -7406.6348898 -7400.32794619 -7378.37114961 -7370.91900161 -7413.36166217 -7413.7162522 -7438.56533502]
# Same diagnostics for the SSVI fit with n_components=4 (params[1]).
show_results(ssvi_store_k3[1], X_true=X_true_k3, dssvi=False)
# Out: [-21501.73106156 -15178.60342157 -14626.64575988 -14070.06900712 -13656.06206499 -13451.73241347 -13192.80572429 -13076.65997796 -12883.63213438 -12678.83364566 -12529.39871013 -12349.80518325 -12112.95513491 -12022.38096588 -11897.92540898 -11671.04291139 -11612.97027664 -11408.25199737 -11158.60936342 -10976.39312333 -10782.49809554 -10623.70328257 -10379.3647695 -10333.63848675 -10103.35191792 -9973.04793034 -9887.47699638 -9788.69074795 -9649.91921522 -9423.10353703 -9287.25105565 -9167.86965266 -8991.66276525 -8970.83673928 -8901.44954391 -8839.12027008 -8727.1543283 -8760.53426286 -8644.80751668 -8580.7245133 -8538.48986397 -8488.93393705 -8420.09660612 -8374.89734418 -8306.9709317 -8230.42591887 -8165.63176492 -8066.20880628 -7963.12058519 -7890.94856798]