python - keras LSTM 获取隐藏状态(将句子序列转换为文档上下文向量)

标签 python keras lstm embedding bert-language-model

我尝试使用 keras 通过 LSTM 从句子向量创建文档上下文向量(因此每个文档都包含一系列句子向量)。

我的目标是使用 keras 复制以下博客文章:https://andriymulyar.com/blog/bert-document-classification

我有一个(玩具)张量,看起来像这样:X = np.array(features).reshape(5, 200, 768) 所以 5 个文档,每个文档都有 200 个序列句子向量 - 每个句子向量有 768 个特征。

因此,为了从我的句子向量中获得嵌入,我将文档编码为 one-hot-vectors 以学习 LSTM:

y = [1,2,3,4,5] # 5 documents in toy-tensor
y = np.array(y)
yy = to_categorical(y)
yy = yy[0:5,1:6]

到目前为止,我的代码看起来像这样

inputs1=Input(shape=(200,768))
lstm1, states_h, states_c =LSTM(5,dropout=0.3,recurrent_dropout=0.2, return_state=True)(inputs1)
model1=Model(inputs1,lstm1)
model1.compile(loss='categorical_crossentropy',optimizer='rmsprop',metrics=['acc']) 
model1.summary()
model1.fit(x=X,y=yy,batch_size=100,epochs=10,verbose=1,shuffle=True,validation_split=0.2)

当我打印 states_h 时,我得到一个 shape=(?,5) 的张量,但我真的不知道如何访问张量内的向量,这些向量应该代表我的文档。

print(states_h)
Tensor("lstm_51/while/Exit_3:0", shape=(?, 5), dtype=float32)

还是我做错了什么?据我了解应该有 5 个文档向量,例如doc1=[...] ; ...; doc5=[...] 以便我可以重用文档向量来执行分类任务。

最佳答案

好吧,打印张量准确地显示了这一点:它是一个张量,它具有那种形状和那种类型。

如果你想看到数据,你需要提供数据。
状态不是权重,它们不是持久的,它们仅与输入数据一起存在,就像任何其他模型输出一样。

您应该创建一个输出此信息的模型(您的模型没有)以便获取它。您可以有两个模型:

#this is the model you compile and train - exactly as you are already doing
training_model = Model(inputs1,lstm1)     

#this is just for getting the states, nothing else, don't compile, don't train
state_getting_model = Model(inputs1, [lstm1, states_h, states_c]) 

(不用担心,即使您只训练training_model,这两个模型也会共享相同的权重并一起更新)

现在您可以:

关闭急切模式(也可能“打开”):

lstm_out, states_h_out, states_c_out = state_getting_model.predict(X)
print(states_h_out)
print(states_c_out)

开启 Eager 模式:

lstm_out, states_h_out, states_c_out = state_getting_model(X)
print(states_h_out.numpy())
print(states_c_out.numpy())

关于python - keras LSTM 获取隐藏状态(将句子序列转换为文档上下文向量),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59498423/

相关文章:

python - 将有状态 LSTM 称为功能模型?

tensorflow - 类型错误 : __init__() got multiple values for argument 'kernel_size'

python - lstm 维度与 tensorflow 不匹配

python - 使用 matplotlib.ticker 设置 Matplotlib Ytick 标签会产生关键错误

python - asyncio.wait_for 不会传播 CancelledError,如果在取消之前等待 future 为 "done"

keras - 是否可以从 Keras 中的 flow_from_directory 自动推断出 class_weight ?

python - 为 LSTM 模型调用预测函数时出现有关输入形状的错误

python - 在具有 Tensorflow 后端的 Keras 上,在不同的输入部分上并行拟合 LSTM 和一些密集层

python - 计算列表中第二个元素的频率? (Python)

python - 在 BeautifulSoup4 的 findAll 中包含多个类名