python - 如何 reshape tflearn 的输入数据?

标签 python pandas numpy tflearn

我正在尝试 reshape 我的数据以与 tflearn 兼容,数据中的每一行的形状为 (1300, 13)。因此,在加载数据并将每个 (1300, 13) 形状的行放入 numpy 数组中后,如下所示:

data_path = os.path.dirname(os.path.realpath(__file__)) + '/../data/data.csv'
train = data.sample(frac=0.8, random_state=200)
test = data.drop(train.index)

train_x = train['lowLevel.mfcc'].as_matrix()
test_x = test['lowLevel.mfcc'].as_matrix()

print(train_x.shape) # (8,)
print(train_x[0].shape) # (1300, 13)


train_y = to_categorical(train['category'], len(categories))
test_y = to_categorical(test['category'], len(categories))

train_x = train_x.reshape([-1, 1300, 13, 1])
test_x = test_x.reshape([-1, 1300, 13, 1])

# ValueError: cannot reshape array of size 8 into shape (1300,13,1)

不知道在这里做什么,我正在复制 MNIST tutorial来自文档: 他们的数据分别是形状 train_x train_y test_x test_y (55000, 10) (55000, 10) (10000, 784) (10000, 10)

我的数据形状是这样的(在我让它工作之前只加载 10 行): (8,) (8, 1) (2,) (2, 1) 当我打印 train_x 时,它看起来像这样: enter image description here

不确定所有数组的情况如何,因为我告诉 Pandas 将列作为矩阵加载...

MNIST 数据可以像这样完美地 reshape :

train_x, train_y, test_x, test_y = mnist.load_data(one_hot=True)

train_x = train_x.reshape([-1, 28, 28, 1])
test_x = test_x.reshape([-1, 28, 28, 1])

我正在从 pandas 数据帧加载数据,但不知道如何将其塑造成这样。

我在 tflearn 中设置了输入层,如下所示:

import tflearn
from tflearn.layers.core import input_data
from tflearn.data_utils import to_categorical

net = input_data(shape=[None, 1300, 13, 1], name='input')

有人知道发生了什么事吗?

最佳答案

弄清楚了,必须预先分配数组:

train_x = np.empty((train['lowLevel.mfcc'].size, 1300, 13))
test_x = np.empty((test['lowLevel.mfcc'].size, 1300, 13))

for index, item in enumerate(train['lowLevel.mfcc']):
    train_x[index] = item

for index, item in enumerate(test['lowLevel.mfcc']):
    test_x[index] = item

关于python - 如何 reshape tflearn 的输入数据?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50040428/

相关文章:

python - 将列表值分配给字典键

python - 在 pandas 系列上使用 apply 方法获取 TypeError 'Series' 对象是可变的,因此无法对其进行哈希处理

python - 使用 python 在 numpy 数组中加载 tiff 堆栈

python - Pandas:从返回数据创建索引时间序列 [从 100 开始]

python - 通过 Post 表单下载文件

python - 不同 GPU 上的 Tensorflow 执行和内存

python - Numpy修改数组到位?

Python Pandas Dataframe idxmax 太慢了。备择方案?

python - Pandas 在 Python 3 中从安全的 FTP 服务器读取数据

python - 带有 numpy 的范围数组