python - 使用keras的句子相似度

标签 python keras sentence-similarity

我正在尝试基于此 work using the STS dataset 实现句子相似度架构.标签是从 0 到 1 的归一化相似性分数,因此假设它是一个回归模型。

我的问题是,从第一个纪元开始,损失直接进入 NaN。我做错了什么?

我已经尝试更新到最新的 keras 和 theano 版本。

我的模型的代码是:

def create_lstm_nn(input_dim):
    seq = Sequential()`
    # embedd using pretrained 300d embedding
    seq.add(Embedding(vocab_size, emb_dim, mask_zero=True, weights=[embedding_weights]))
    # encode via LSTM
    seq.add(LSTM(128))
    seq.add(Dropout(0.3))
    return seq

lstm_nn = create_lstm_nn(input_dim)

input_a = Input(shape=(input_dim,))
input_b = Input(shape=(input_dim,))

processed_a = lstm_nn(input_a)
processed_b = lstm_nn(input_b)

cos_distance = merge([processed_a, processed_b], mode='cos', dot_axes=1)
cos_distance = Reshape((1,))(cos_distance)
distance = Lambda(lambda x: 1-x)(cos_distance)

model = Model(input=[input_a, input_b], output=distance)

# train
rms = RMSprop()
model.compile(loss='mse', optimizer=rms)
model.fit([X1, X2], y, validation_split=0.3, batch_size=128, nb_epoch=20)

我也尝试使用简单的 Lambda 代替 Merge 层,但结果相同。

def cosine_distance(vests):
    x, y = vests
    x = K.l2_normalize(x, axis=-1)
    y = K.l2_normalize(y, axis=-1)
    return -K.mean(x * y, axis=-1, keepdims=True)

def cos_dist_output_shape(shapes):
    shape1, shape2 = shapes
    return (shape1[0],1)

distance = Lambda(cosine_distance, output_shape=cos_dist_output_shape)([processed_a, processed_b])

最佳答案

nan 是深度学习回归中的常见问题。因为你使用的是暹罗网络,你可以尝试以下操作:

  1. 检查您的数据:它们是否需要标准化?
  2. 尝试在你的网络中添加一个 Dense 层作为最后一层,但要小心选择激活函数,例如relu
  3. 尝试使用其他损失函数,例如对比损失
  4. 降低学习率,例如0.0001
  5. cos模式没有仔细处理被零除,可能是NaN的原因

让深度学习完美运行并不容易。

关于python - 使用keras的句子相似度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39289050/

相关文章:

python - 属性错误 : 'list' object has no attribute 'rfind' using petsc4py

python - 将字符串存储到json对象python

python - 继承列表 : Creating division by other lists, 整数和 float

python-3.x - 如何使用opencv python查找图像中的最小矩形?

python - 如何在 Keras 中使用 Hausdorff 度量?

python - 如何计算两个 n-gram 之间的语义相似度?

python - 在 Keras 中使用通用句子编码器嵌入层

Python Mayavi : Adding points to a 3d scatter plot

python - keras CNN 相同的输出

sentence-similarity - 如何在训练后在本地保存 SetFit 训练器