随着 tf.contrib 模块从 Tensorflow 中消失,并且 tf.train.Saver() 也消失了,我找不到存储一组嵌入及其相应缩略图的方法,以便 Tensorboard Projector 可以读取它们。
Tensorboard documentation Tensorflow 2.0 解释了如何创建绘图和摘要,以及如何使用一般的摘要工具,但没有关于投影仪工具。有没有人发现如何存储数据集以进行可视化?
如果可能的话,我会很感激一个(最小的)代码示例。
最佳答案
由于缺乏文档,似乎很多人都在使用 TF2.x 中的 Tensorboard Projector 时遇到问题。我设法让它工作,在这个 comment on GitHub我提供了一些最小的代码示例。我知道这些问题也与使用缩略图( Sprite )有关,但我不需要它并且想保持示例简单,所以让 Sprite 工作留给读者作为练习。
# Some initial code which is the same for all the variants
import os
import numpy as np
import tensorflow as tf
from tensorboard.plugins import projector
def register_embedding(embedding_tensor_name, meta_data_fname, log_dir):
config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_tensor_name
embedding.metadata_path = meta_data_fname
projector.visualize_embeddings(log_dir, config)
def get_random_data(shape=(100,100)):
x = np.random.rand(*shape)
y = np.random.randint(low=0, high=2, size=shape[0])
return x, y
def save_labels_tsv(labels, filepath, log_dir):
with open(os.path.join(log_dir, filepath), 'w') as f:
for label in labels:
f.write('{}\n'.format(label))
LOG_DIR = 'tmp' # Tensorboard log dir
META_DATA_FNAME = 'meta.tsv' # Labels will be stored here
EMBEDDINGS_TENSOR_NAME = 'embeddings'
EMBEDDINGS_FPATH = os.path.join(LOG_DIR, EMBEDDINGS_TENSOR_NAME + '.ckpt')
STEP = 0
x, y = get_random_data((100,100))
register_embedding(EMBEDDINGS_TENSOR_NAME, META_DATA_FNAME, LOG_DIR)
save_labels_tsv(y, META_DATA_FNAME, LOG_DIR)
VARIANT A(适用于 TF2.0 和 TF2.1,但不适用于 Eager 模式)
# Size of files created on disk: 163kB
tf.compat.v1.disable_eager_execution()
tensor_embeddings = tf.Variable(x, name=EMBEDDINGS_TENSOR_NAME)
sess = tf.compat.v1.InteractiveSession()
sess.run(tf.compat.v1.global_variables_initializer())
saver = tf.compat.v1.train.Saver()
saver.save(sess, EMBEDDINGS_FPATH, STEP)
sess.close()
VARIANT B(在 Eager 模式下同时在 TF2.0 和 TF2.1 中工作)
# Size of files created on disk: 80.5kB
tensor_embeddings = tf.Variable(x, name=EMBEDDINGS_TENSOR_NAME)
saver = tf.compat.v1.train.Saver([tensor_embeddings]) # Must pass list or dict
saver.save(sess=None, global_step=STEP, save_path=EMBEDDINGS_FPATH)
我要感谢其他开发人员从他们的 Stack 答案、GitHub 评论或个人博客文章中提供的一些代码,它们帮助我将这些示例放在一起。你是真正的MVP。
关于python-3.x - 如何在 Tensorflow 2.0 中使用嵌入投影仪,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57014236/