python - 创建一个输出字典的 tensorflow 数据集

标签 python dictionary tensorflow tensorflow-datasets

我的数据集有一个带有“元数据”的字典 {'m1': array_1, 'm2': array_2, ...}。每个数组的形状都是 (N, ...),其中 N 是样本数。

问题: 是否可以创建 tf.data.Dataset为数据集 iterator.get_next() 的每次迭代输出字典 {'meta_1': sub_array_1, 'meta_2': sub_array_2, ...}?此处,sub_array_i 应包含一个批处理的第 i 个元数据,因此形状应为 (batch_sz, ...)。

到目前为止我尝试使用的是 tf.data.Dataset.from_generator() ,像这样:

N = 100
# dictionary of arrays:
metadata = {'m1': np.zeros(shape=(N,2)), 'm2': np.ones(shape=(N,3,5))} 
num_samples = N

def meta_dict_gen():
    for i in range(num_samples):
        ls = {}
        for key, val in metadata.items():
            ls[key] = val[i]
        yield ls

dataset = tf.data.Dataset.from_generator(meta_dict_gen, output_types=(dict))

这个问题似乎在 output_types=(dict) 中。上面的代码向我抛出一个

TypeError: Expected DataType for argument 'Tout' not < class 'dict'>.


我正在使用 tensorflow 1.8 和 python 3.6。

最佳答案

所以实际上可以做你想做的,你只需要具体说明字典的内容:

import tensorflow as tf
import numpy as np

N = 100
# dictionary of arrays:
metadata = {'m1': np.zeros(shape=(N,2)), 'm2': np.ones(shape=(N,3,5))}
num_samples = N

def meta_dict_gen():
    for i in range(num_samples):
        ls = {}
        for key, val in metadata.items():
            ls[key] = val[i]
        yield ls

dataset = tf.data.Dataset.from_generator(
    meta_dict_gen,
    output_types={k: tf.float32 for k in metadata},
    output_shapes={'m1': (2,), 'm2': (3, 5)})
iter = dataset.make_one_shot_iterator()
next_elem = iter.get_next()
print(next_elem)

输出:

{'m1': <tf.Tensor 'IteratorGetNext:0' shape=(2,) dtype=float32>,
 'm2': <tf.Tensor 'IteratorGetNext:1' shape=(3, 5) dtype=float32>}

关于python - 创建一个输出字典的 tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51136862/

相关文章:

python - 使用 Python 驱动程序连接到 Dockerized Clickhouse Server 时出现问题

python - 比较两个字典,如果一个字典中存在键/值对,则删除另一个字典中的键/值对

objective-c - 将字典项从 plist 转换为 NSArray

python - 计算与 tensorflow 运算并集的平均交集而不显式调用更新运算?

python - 使用 Python-Requests 库发布文本文件

Python:使用 numpy 数组时避免内存错误的替代方法?

python - 在 python 中读取缺少项目的列

c++ - 类中的 map<string, string>

python - 与 Tensorflow 2.0 同一层的不同尺寸过滤器

python - 在c++中嵌入python时导入tensorflow返回null