python - CNTK:从用于多 GPU 训练的 numpy 数组创建 MinibatchSource

标签 python numpy deep-learning cntk

我在 numpy 数组中有我的预处理图像数据,我的脚本在单个 GPU 上工作正常 feeding numpy array .据我了解,我们需要创建 MinibatchSource用于多 GPU 训练。我正在检查此示例 ( ConvNet_CIFAR10_DataAug_Distributed.py ) 以进行分布式训练,但它使用 *_map.txt,它基本上是图像文件(例如 png)的路径列表。我想知道最好的方法是从 numpy 数组创建 MinibatchSource,而不是将 numpy 数组转换回 png 文件。

最佳答案

您可以创建复合读取器,将多个图像反序列化器组合到一个源中。首先,您需要创建两个 map 文件(带有虚拟标签)。一个将包含所有输入图像,另一个将包含相应的目标图像。以下代码是一个最小的实现,假设文件名为 map1.txtmap2.txt

import numpy as np
import cntk as C
import cntk.io.transforms as xforms 
import sys

def create_reader(map_file1, map_file2):
    transforms = [xforms.scale(width=224, height=224, channels=3, interpolations='linear')]
    source1 = C.io.ImageDeserializer(map_file1, C.io.StreamDefs(
        source_image = C.io.StreamDef(field='image', transforms=transforms)))
    source2 = C.io.ImageDeserializer(map_file2, C.io.StreamDefs(
        target_image = C.io.StreamDef(field='image', transforms=transforms)))
    return C.io.MinibatchSource([source1, source2], max_samples=sys.maxsize, randomize=True)

x = C.input_variable((3,224,224))
y = C.input_variable((3,224,224))
# world's simplest model
model = C.layers.Convolution((3,3),3, pad=True)
z = model(x)
loss = C.squared_error(z, y)

reader = create_reader("map1.txt", "map2.txt")
trainer = C.Trainer(z, loss, C.sgd(z.parameters, C.learning_rate_schedule(.00001, C.UnitType.minibatch)))

minibatch_size = 2

input_map={
    x: reader.streams.source_image,
    y: reader.streams.target_image
}

for i in range(30):
    data=reader.next_minibatch(minibatch_size, input_map=input_map)
    print(data)
    trainer.train_minibatch(data)

关于python - CNTK:从用于多 GPU 训练的 numpy 数组创建 MinibatchSource,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43151213/

相关文章:

python - 在 Python3/Numpy 中过滤数组并返回索引

python - C 中的 Numpy 数组

python - keras如何使用ReduceLROnPlateau

python - 了解 Keras 中语音识别的 CTC 损失

python - 什么时候在keras的源代码中调用Layer.build()?

android - Android上的Python 'invalid syntax'错误

python - python(和编程)新手需要一些关于对角线排列的建议

python - 如何旋转数据框?

python - 为什么将 DatetimeIndex 转换为 np.array 时格式会发生变化?

python - 在 Keras 模型中删除然后插入一个新的中间层