python - Scikit-learn : Input contains NaN, 无穷大或 dtype ('float64' 的值太大)

标签 python numpy machine-learning scikit-learn

我正在使用 Python scikit-learn 对从 csv 获得的数据进行简单线性回归。

reader = pandas.io.parsers.read_csv("data/all-stocks-cleaned.csv")
stock = np.array(reader)

openingPrice = stock[:, 1]
closingPrice = stock[:, 5]

print((np.min(openingPrice)))
print((np.min(closingPrice)))
print((np.max(openingPrice)))
print((np.max(closingPrice)))

peningPriceTrain, openingPriceTest, closingPriceTrain, closingPriceTest = \
    train_test_split(openingPrice, closingPrice, test_size=0.25, random_state=42)


openingPriceTrain = np.reshape(openingPriceTrain,(openingPriceTrain.size,1))

openingPriceTrain = openingPriceTrain.astype(np.float64, copy=False)
# openingPriceTrain = np.arange(openingPriceTrain, dtype=np.float64)

closingPriceTrain = np.reshape(closingPriceTrain,(closingPriceTrain.size,1))
closingPriceTrain = closingPriceTrain.astype(np.float64, copy=False)

openingPriceTest = np.reshape(openingPriceTest,(openingPriceTest.size,1))
closingPriceTest = np.reshape(closingPriceTest,(closingPriceTest.size,1))

regression = linear_model.LinearRegression()

regression.fit(openingPriceTrain, closingPriceTrain)

predicted = regression.predict(openingPriceTest)

最小值和最大值显示为 0.0 0.6 41998.0 2593.9

但我收到此错误 ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

我应该如何消除这个错误? 因为从上面的结果来看,它确实不包含无穷大或 Nan 值。

解决这个问题的方法是什么?

编辑:all-stocks-cleaned.csv 在 http://www.sharecsv.com/s/cb31790afc9b9e33c5919cdc562630f3/all-stocks-cleaned.csv 可用

最佳答案

您的回归问题在于 NaN 以某种方式潜入了您的数据。这可以使用以下代码片段轻松检查:

import pandas as pd
import numpy as np
from  sklearn import linear_model
from sklearn.cross_validation import train_test_split

reader = pd.io.parsers.read_csv("./data/all-stocks-cleaned.csv")
stock = np.array(reader)

openingPrice = stock[:, 1]
closingPrice = stock[:, 5]

openingPriceTrain, openingPriceTest, closingPriceTrain, closingPriceTest = \
    train_test_split(openingPrice, closingPrice, test_size=0.25, random_state=42)

openingPriceTrain = openingPriceTrain.reshape(openingPriceTrain.size,1)
openingPriceTrain = openingPriceTrain.astype(np.float64, copy=False)

closingPriceTrain = closingPriceTrain.reshape(closingPriceTrain.size,1)
closingPriceTrain = closingPriceTrain.astype(np.float64, copy=False)

openingPriceTest = openingPriceTest.reshape(openingPriceTest.size,1)
openingPriceTest = openingPriceTest.astype(np.float64, copy=False)

np.isnan(openingPriceTrain).any(), np.isnan(closingPriceTrain).any(), np.isnan(openingPriceTest).any()

(True, True, True)

如果您尝试像下面这样估算缺失值:

openingPriceTrain[np.isnan(openingPriceTrain)] = np.median(openingPriceTrain[~np.isnan(openingPriceTrain)])
closingPriceTrain[np.isnan(closingPriceTrain)] = np.median(closingPriceTrain[~np.isnan(closingPriceTrain)])
openingPriceTest[np.isnan(openingPriceTest)] = np.median(openingPriceTest[~np.isnan(openingPriceTest)])

您的回归将顺利进行,不会出现问题:

regression = linear_model.LinearRegression()

regression.fit(openingPriceTrain, closingPriceTrain)

predicted = regression.predict(openingPriceTest)

predicted[:5]

array([[ 13598.74748173],
       [ 53281.04442146],
       [ 18305.4272186 ],
       [ 50753.50958453],
       [ 14937.65782778]])

简而言之:如错误消息所述,您的数据中存在缺失值。

编辑::

也许更简单、更直接的方法是在使用 pandas 读取数据后立即检查是否有任何缺失的数据:

data = pd.read_csv('./data/all-stocks-cleaned.csv')
data.isnull().any()
Date                    False
Open                     True
High                     True
Low                      True
Last                     True
Close                    True
Total Trade Quantity     True
Turnover (Lacs)          True

然后用以下两行中的任何一行来估算数据:

data = data.fillna(lambda x: x.median())

data = data.fillna(method='ffill')

关于python - Scikit-learn : Input contains NaN, 无穷大或 dtype ('float64' 的值太大),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34779961/

相关文章:

python - 如何将 x、y 坐标编码为来自 Dart 的 python ecdsa VerifyingKey 格式

python - 相当于在 Tensorflow 中设置类似 Numpy 的掩码值?

machine-learning - 使用完整数据集进行预测是一个好习惯吗?

machine-learning - 反向传播,所有输出趋于1

matlab - matlab中的支持向量机

python - “uuid”是该函数的无效关键字参数

Python urllib2 HTTPBasicAuthHandler

python - Python 中比较字符串元素的最快方法

python - 如何在Python中使用高斯kde内核设置带宽来平滑线条

python - 使用 NumPy 将一个数组与另一个数组建立索引