我有 3D 图像(tiff)数据和文件夹内的每个卷。我想读取数据并为卷积网络制作批量张量。我可以将数据读取为 numpy 数组,但我不知道如何为 CNN 进行批量张量输入。这是我的代码
import os
import tensorflow as tf
import numpy as np
from skimage import io
from matplotlib import pyplot as plt
from pathlib import Path
data_dir = 'C:/Users/myname/Documents/Projects/Segmentation/DeepLearning/L-net/data/'
data_folders = os.listdir(data_dir)
train_input = []
train_output = []
test_input = []
test_output = []
for idx, folder in enumerate(data_folders):
im = io.imread(data_dir+folder+'/f0.tiff')
im = im/im.max()
train_input.append(tf.convert_to_tensor(im, dtype=tf.float32))
im = io.imread(data_dir+folder+'/g0.tiff')
im = im/im.max()
train_output.append(tf.convert_to_tensor(im, dtype=tf.float32))
由于我在 CNN 中使用 3D 滤波器,因此输入应该是 5D 张量。有人可以帮我弄这个吗?谢谢。
最佳答案
采用您的方法,您必须立即将所有数据加载到内存中,并且还必须处理所有维度。我建议使用 Keras flow_from_directory
和 generators
。 Keras 有这个 ImageDataGenerator 类,它允许用户从目录中执行图像收集,将所有图像更改为您想要的任何大小,对它们进行随机播放,...。您可以找到文档here在他们的网站上。
Download the train dataset and test dataset, extract them into 2 different folders named as “train” and “test”. The train folder should contain ‘n’ folders each containing images of respective classes. For example, In the Dog vs Cats data set, the train folder should have 2 folders, namely “Dog” and “Cats” containing respective images inside them.
这是有关如何为模型输入创建数据集的示例:
train_generator = train_datagen.flow_from_directory(
directory=r"C:/Users/myname/Documents/Projects/Segmentation/DeepLearning/L-net/data/",
target_size=(224, 224), # the size of your input images
color_mode="rgb", # could be grayscale or rgb
batch_size=32, # Number of images in each batsh
class_mode="categorical",
shuffle=True, # Whether to shuffle the images or not
seed=42 # Random seed for applying random image augmentation
)
您可以像这样进行训练:
STEP_SIZE_TRAIN=train_generator.n//train_generator.batch_size
STEP_SIZE_VALID=valid_generator.n//valid_generator.batch_size
model.fit_generator(generator=train_generator,
steps_per_epoch=STEP_SIZE_TRAIN,
validation_data=valid_generator,
validation_steps=STEP_SIZE_VALID,
epochs=10
)
关于python - TensorFlow:从 3D CNN 的 numpy 数组获取输入批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59184700/