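I need to create a neural network layer that optimizes the three parameters of a GeneralizedExtremeValue distribution (loc, scale, and concentration), but when multiple parameters are passed to the DistributionLambda layer, all training metrics are nan and the output distribution is empty.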
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
tfkl = tf.keras.layers
tfpl = tfp.layers
# INPUT DATASETS ==========================================================
# creating a GEV distribution
dist = tfd.GeneralizedExtremeValue(loc=2, scale=1, concentration=0.1)
# sampling the GEV to create a dataset
dataset = dist.sample(10**(5))
# creating some noise on the sample
def add_eps(values):
    eps = np.random.randn(len(values))
    return values + eps
x_train = (0.75* add_eps(dataset[:(8*10**4)]))
y_train = dataset[:(8*10**4)]
x_test = (0.75* add_eps(dataset[(8*10**4):]))
# plotting the input datasets
fig=plt.figure()
sns.histplot(dataset, bins=100,
             stat='probability', kde=True,
             color='r', label='Sample from GEV: y_train')
sns.histplot(x_train, bins=100,
             stat='probability',
             kde=True,
             color='b', label='Modified sample from GEV: x_train')
plt.legend()
plt.show()
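Plotting x_train and y_train, the input datasets look correct. Even so, the distribution output by the neural network is nan, and so are its mean and stddev. I tried changing batch_size, the number of epochs, and the number of outputs of the Dense layer, but nothing helped.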
# CREATING BAYESIAN NETWORK ================================================
# model architecture
model = tf.keras.Sequential([
    tfkl.Dense(3, input_shape=(1,)),
    tfpl.DistributionLambda(
        lambda t: tfd.GeneralizedExtremeValue(loc=t[0],
                                              scale=t[1], concentration=t[2]))
])
negloglik = lambda y_true, y_pred: -y_pred.log_prob(y_true)
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01),
loss=negloglik)
# Model Fitting
history = model.fit(x_train, y_train,
                    validation_split=0.2, epochs=100,
                    verbose=True, shuffle=True,
                    batch_size=300)
# predictions
y_pred_mean = model(x_test.numpy().reshape(-1,1)).mean()
y_pred_std = model(x_test.numpy().reshape(-1,1)).stddev()
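Best answer
Try adding some constraints to the parameters. I'm not familiar with this distribution, but I'm fairly sure sigma (the scale) should be greater than zero, so in your case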
scale = 1e-3 + tf.math.softplus(0.01 *t[1])
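The effect of that transform can be checked outside the model. Below is a minimal NumPy sketch, not the asker's code: the helper name `constrain_gev_params` and the sample `raw` values are made up for illustration, and it assumes the Dense layer's output has shape (batch, 3), so the parameters are taken from the last axis (`t[..., 1]`) rather than from a batch row (`t[1]`).

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def constrain_gev_params(t, min_scale=1e-3, slope=0.01):
    """Map raw outputs of shape (batch, 3) to (loc, scale, concentration).

    Hypothetical helper: only the scale is constrained, mirroring
    1e-3 + softplus(0.01 * raw_scale) from the answer.
    """
    t = np.asarray(t, dtype=np.float64)
    loc = t[..., 0]                                   # unconstrained
    scale = min_scale + softplus(slope * t[..., 1])   # strictly positive
    concentration = t[..., 2]                         # unconstrained
    return loc, scale, concentration


# Made-up raw outputs, including large negative and positive scale logits.
raw = np.array([[2.0, -50.0, 0.1],
                [1.5, 0.0, -0.2],
                [0.0, 300.0, 0.3]])
loc, scale, concentration = constrain_gev_params(raw)
print(scale)  # every entry is greater than zero
```

Inside the DistributionLambda lambda the same mapping would use `tf.math.softplus` in place of the NumPy helper. The small offset keeps the scale away from zero, and the 0.01 factor flattens the transform so that early, poorly scaled Dense outputs do not produce extreme scales.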
Regarding "python - Can I optimize three parameters with a DistributionLambda layer?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/70347663/