numpy - 带有浮点 numpy 数组的 tensorflow 记录

标签 numpy tensorflow

我想创建 tensorflow 记录来为我的模型提供数据;
到目前为止,我使用以下代码将 uint8 numpy 数组存储为 TFRecord 格式;

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _floats_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def convert_to_record(name, image, label, map):
    filename = os.path.join(params.TRAINING_RECORDS_DATA_DIR, name + '.' + params.DATA_EXT)

    writer = tf.python_io.TFRecordWriter(filename)

    image_raw = image.tostring()
    map_raw   = map.tostring()
    label_raw = label.tostring()

    example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image_raw),
        'map_raw': _bytes_feature(map_raw),
        'label_raw': _bytes_feature(label_raw)
    }))        
    writer.write(example.SerializeToString())
    writer.close()

我用这个示例代码阅读的
features = tf.parse_single_example(example, features={
  'image_raw': tf.FixedLenFeature([], tf.string),
  'map_raw': tf.FixedLenFeature([], tf.string),
  'label_raw': tf.FixedLenFeature([], tf.string),
})

image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape(params.IMAGE_HEIGHT*params.IMAGE_WIDTH*3)
image = tf.reshape(image_, (params.IMAGE_HEIGHT,params.IMAGE_WIDTH,3))

map = tf.decode_raw(features['map_raw'], tf.uint8)
map.set_shape(params.MAP_HEIGHT*params.MAP_WIDTH*params.MAP_DEPTH)
map = tf.reshape(map, (params.MAP_HEIGHT,params.MAP_WIDTH,params.MAP_DEPTH))

label = tf.decode_raw(features['label_raw'], tf.uint8)
label.set_shape(params.NUM_CLASSES)

这工作正常。现在我想对我的数组“map”做同样的事情,它是一个浮点 numpy 数组,而不是 uint8,我找不到如何做的例子;
我尝试了函数 _floats_feature,如果我将标量传递给它,它会起作用,但不适用于数组;
使用 uint8 可以通过 tostring() 方法完成序列化;

如何序列化一个 float numpy 数组以及如何读取它?

最佳答案

FloatListBytesList期待一个迭代。所以你需要向它传递一个浮点数列表。删除 _float_feature 中多余的括号, IE

def _floats_feature(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

numpy_arr = np.ones((3,)).astype(np.float)
example = tf.train.Example(features=tf.train.Features(feature={"bytes": _floats_feature(numpy_arr)}))
print(example)

features {
  feature {
    key: "bytes"
    value {
      float_list {
        value: 1.0
        value: 1.0
        value: 1.0
      }
    }
  }
}

关于numpy - 带有浮点 numpy 数组的 tensorflow 记录,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41246438/

相关文章:

python - 使用 cv.fromarray 将转置的 NumPy 数组转换为 Python 中的 CvMat 类型

python - reshape 3D 数组中的一组数组

python - 如何将不同长度的时间窗应用于 Pandas 数据框

tensorflow - CNN 等神经网络中的求和和串联有什么区别?

java - 如何更改加载的 tensorflow 模型中的设备配置

python - 带有 TensorFlow 的推荐系统 (SVD)

python - 使用 LAPACK/BLAS 安装 numpy 的最简单方法是什么?

python - 具有相同索引的两个数据帧的 Pandas 外积

python - 无法从tensorflow.keras.metrics导入指标

python - ValueError:检查目标时出错:预期 main_prediction 有 3 个维度,但得到形状为 (1128, 1) 的数组