python-3.x - 批量读取Cifar10数据集

标签 python-3.x machine-learning computer-vision batch-processing

我正在尝试读取 CIFAR10 数据集,从 https://www.cs.toronto.edu/~kriz/cifar.html 中批量给出>。我正在尝试使用 pickle 将其放入数据框中并读取其中的“数据”部分。但我收到此错误。

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

key 错误:“数据”。

我正在使用 ipython,这是我的代码:

def unpickle(file):

 fo = open(file, 'rb')
 dict = pickle.load(fo, encoding ='bytes')
 X = dict['data']
 fo.close()
 return dict

unpickle('datasets/cifar-10-batches-py/test_batch')

最佳答案

您可以通过下面给出的代码读取 cifar 10 数据集,只需确保您提供的是放置批处理的写入目录

import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

%matplotlib inline


img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000,3072)
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = '../input/cifar-10-batches-py/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    x_train = X_train.astype('float32')
    x_test = X_test.astype('float32')

    x_train /= 255
    x_test /= 255

    return x_train, y_train, X_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()


print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)

关于python-3.x - 批量读取Cifar10数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37512290/

相关文章:

android - Android 运行时的 Tensorflow 人脸识别

opencv - 自适应阈值中窗口大小的选择

machine-learning - 卷积神经网络 (CNN) 的训练和验证准确性突然下降

python - 与 python 3.1 和 py-postgresql 兼容的 web 框架

python-3.x - 通过 gspread 对电子表格进行排序

r - 使用现有数据和概率模拟数据

opencv - cv2 霍夫线的霍夫空间

python - Django - 在views.py内部使用我自己的REST API的正确方法?

python - 如何设置python中字符的二进制形式显示的字节数?

python - 如何在 PyTorch 中构建具有两个输入的网络