对于以下数据框数据:
x y a b c
2 6 12 1 2
1 2 4 6 8
我想要新列(即 d)中的结果,它返回仅在 a、b、c 中具有最大值的列的名称。
cols
a
c
我试图从三列中找到最大值并返回列名。但是我不想选择数据集的所有行,而是只想选择这三列的行。我使用以下代码:
def returncolname(row, colnames):
return colnames[np.argmax(row.values)]
data['colmax'] = data.apply(lambda x: returncolname(x, data.columns), axis=1)
我能想到的最快的解决方案是DataFrame.dot
:
df.eq(df.max(1), axis=0).dot(df.columns)
详情
首先,计算每行的最大值:
df.max(1)
0 12
1 8
dtype: int64
接下来,找到这些值来自的位置:
df.eq(df.max(1), axis=0)
x y a b c
0 False False True False False
1 False False False False True
我使用 eq
来确保比较在各列中正确传播。
接下来,计算列列表的点积:
df.eq(df.max(1), axis=0).dot(df.columns)
0 a
1 c
dtype: object
如果最大值不唯一,使用
df.eq(df.max(1), axis=0).dot(df.columns + ',').str.rstrip(',')
获取以逗号分隔的列列表。例如,
更改几个值:
df.at[0, 'c'] = 12
df.at[1, 'y'] = 8
一切都是一样的,但请注意我在每一列后面都附加了一个逗号:
df.columns + ','
Index(['x,', 'y,', 'a,', 'b,', 'c,'], dtype='object')
df.eq(df.max(1), axis=0).dot(df.columns + ',')
0 a,c,
1 y,c,
dtype: object
从这里,去掉任何尾随的逗号:
df.eq(df.max(1), axis=0).dot(df.columns + ',').str.rstrip(',')
0 a,c
1 y,c
dtype: object