python - Pandas groupby & linregress 如何提取

标签 python pandas

我正在按组对数据框进行线性回归以生成汇总统计数据。我已经使用 scipy linregress 计算了两个变量 km vs price 的回归:

import pandas as pd
from scipy.stats import linregress    
df = pd.read_csv('test dataset faceted small.csv')
grouped = df.groupby(['year','make','engine','drive','transmission','badge'])
test = grouped.apply(lambda x: linregress(x['km'], x['price']))
print test
test.to_csv('grouped.csv', index=False)

打印测试给了我:
year  make    engine  drive  transmission  badge                
1994  subaru  1.6L    awd    auto          wrx                      (-0.0019029525668, 2217.67284738, -0.190381626...
1997  mazda   1.3L    2wd    manual        121 metro                (-0.00724142957301, 4213.71579612, -0.30608491...
1999  nissan  1.6L    2wd    auto          pulsar plus lx n15 s2    (-0.00245336355614, 3653.42015515, -0.17060101...

保存到 csv 的测试是:
LinregressResult(slope=-0.0019029525667976811, intercept=2217.6728473825792, rvalue=-0.19038162624636565, pvalue=4.2750387135904842e-07, stderr=0.00037275167083276965)
LinregressResult(slope=-0.0072414295730094738, intercept=4213.7157961188113, rvalue=-0.30608491681348643, pvalue=4.8781453623746113e-17, stderr=0.00084171437048465665)
LinregressResult(slope=-0.0024533635561369252, intercept=3653.4201551461483, rvalue=-0.17060101350197393, pvalue=1.4676330869804576e-07, stderr=0.0004631573671617427)

但是我想要的 csv 输出是:
year  make    engine  drive  transmission  badge                   slope             intercept       rvalue      
1994  subaru  1.6L    awd    auto          wrx                     -0.0019029525668  2217.67284738 -0.190381626...
1997  mazda   1.3L    2wd    manual        121 metro               -0.00724142957301 4213.71579612 -0.30608491...
1999  nissan  1.6L    2wd    auto          pulsar plus lx n15 s2   -0.00245336355614 3653.42015515 -0.17060101...

以便我以后可以轻松调用结果。如何将 LinregressResult 附加到每个组并将它们保存到 csv?

最佳答案

我想你可以简单地这样做:

test = (grouped.apply(lambda x: pd.Series(linregress(x['km'], x['price'])))
               .rename(columns={
                        0: 'slope',
                        1: 'intercept',
                        2: 'rvalue',
                        3: 'pvalue',
                        4: 'stderr'
                      })
       )

代替
test = grouped.apply(lambda x: linregress(x['km'], x['price']))

示范:
rows = 10

# generate random integer numbers
df = pd.DataFrame(np.random.randint(0, 10, size=(rows, 5)), columns=list('abcde'))

def linregress(x):
    # imitates `linregress`
    # returns tuples 
    return tuple(x)

test = (df.apply(lambda x: pd.Series(linregress(x)), axis=1)
          .rename(columns={
                   0: 'slope',
                   1: 'intercept',
                   2: 'rvalue',
                   3: 'pvalue',
                   4: 'stderr'
                 })
       )

输出:
In [48]: df.apply(lambda x: linregress(x), axis=1)
Out[48]:
0    (7, 7, 2, 0, 0)
1    (6, 9, 3, 1, 5)
2    (5, 1, 6, 1, 3)
3    (4, 4, 2, 1, 4)
4    (8, 7, 1, 5, 4)
5    (0, 2, 7, 6, 1)
6    (3, 8, 4, 2, 8)
7    (6, 0, 0, 3, 2)
8    (9, 4, 6, 2, 3)
9    (8, 1, 7, 9, 8)
dtype: object


In [50]: test = (df.apply(lambda x: pd.Series(linregress(x)), axis=1)
   ....:           .rename(columns={
   ....:                    0: 'slope',
   ....:                    1: 'intercept',
   ....:                    2: 'rvalue',
   ....:                    3: 'pvalue',
   ....:                    4: 'stderr'
   ....:                  })
   ....:        )

In [51]: test
Out[51]:
   slope  intercept  rvalue  pvalue  stderr
0      7          7       2       0       0
1      6          9       3       1       5
2      5          1       6       1       3
3      4          4       2       1       4
4      8          7       1       5       4
5      0          2       7       6       1
6      3          8       4       2       8
7      6          0       0       3       2
8      9          4       6       2       3
9      8          1       7       9       8

关于python - Pandas groupby & linregress 如何提取,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37084446/

相关文章:

python - 每 7 行中,获取第 n 行 pandas

python-3.x - 过滤并获取数据框中条件之间的行

python - 在 shell 中运行 python 的脚本

python - 查找最近 30 分钟内 DataFrame 中的元素数

python - 在 Spyder IDE 上运行 Pyinstaller(或 Anaconda3 提示符)

python - 在python中读取大数据

python - pandas.concat : Cannot handle a non-unique multi-index! Pandas Python

python - Pandas - 取消堆叠到顶部列级别

python - 将字典映射到数据框无法正常工作

python - Python 3.x 中字符串的内部表示是什么