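I am struggling with a slow seaborn FacetGrid heatmap.
I have expanded my dataset from a previous problem, and thanks to @Diziet Asahi for providing the solution to the FacetGrid problem.
Now I have roughly a 20x20 grid (20 methods by 19 labels in the code below), with 625 points to map in each cell. Even a single layer, little1, takes a very long time to produce output. My real data has thousands of such little layers.
My code looks like this: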
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    product = list(itertools.product(*itrs))
    return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}
ltt= ['little1']
methods=["m" + str(i) for i in range(1,21)]
labels=["l" + str(i) for i in range(1,20)]
times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
data['nw_score'] = np.random.choice([0,1],data.shape[0])
data
This outputs:
Out[36]:
            ltt  method  labels  dtsi  rtsi  nw_score
0       little1      m1      l1     0     0         1
1       little1      m1      l1     0     4         0
2       little1      m1      l1     0     8         0
3       little1      m1      l1     0    12         1
4       little1      m1      l1     0    16         0
...         ...     ...     ...   ...   ...       ...
237495  little1     m20     l19    96    80         0
237496  little1     m20     l19    96    84         1
237497  little1     m20     l19    96    88         0
237498  little1     m20     l19    96    92         0
237499  little1     m20     l19    96    96         1
[237500 rows x 6 columns]
Defining the facet function and plotting:
labels_fill = {0:'red',1:'blue'}
del methods
del labels

def facet(data, color):
    # Reshape one facet's rows into a dtsi x rtsi matrix and draw it as a heatmap
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    g = sns.heatmap(data, cmap=ListedColormap(['red', 'blue']), cbar=False, annot=True)

for lt in data.ltt.unique():
    with sns.plotting_context(font_scale=5.5):
        # height= is the seaborn >= 0.9 name for the older size= argument
        g = sns.FacetGrid(data[data.ltt==lt], row="labels", col="method", height=2, aspect=1, margin_titles=False)
        g = g.map_dataframe(facet)
        g.add_legend()
        g.set_titles(template="")
        for ax,method in zip(g.axes[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(g.axes[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        g.fig.suptitle(lt, fontweight='bold', fontsize=12)
        g.fig.tight_layout()
        g.fig.subplots_adjust(top=0.8) # make some room for the title
        g.savefig(lt+'.png', dpi=300)
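I stopped the code after a while. We can see the grid being filled in one cell at a time, which is very time-consuming. Generating this heatmap is unbearably slow.
Is there a better way to speed this up?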
Best answer
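Seaborn is slow. If you use matplotlib instead of seaborn, each figure takes only about half a minute. That is still long, but it is to be expected given that you are generating a figure of roughly 12000x12000 pixels.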
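The pixel estimate can be checked with a little arithmetic. The sketch below assumes the figure size used in this answer's code, figsize=(2*len(cols)+1, 2*len(rows)+1) inches, and the dpi=300 passed to savefig. The helper name figure_pixels is hypothetical and only serves the illustration.

```python
def figure_pixels(n_rows, n_cols, dpi=300):
    """Pixel size of a saved figure whose size in inches is
    (2*n_cols + 1, 2*n_rows + 1), as in the matplotlib code of this answer."""
    width_in = 2 * n_cols + 1
    height_in = 2 * n_rows + 1
    return width_in * dpi, height_in * dpi

# 20 methods (columns) and 19 labels (rows), as generated by the question's code
width_px, height_px = figure_pixels(n_rows=19, n_cols=20, dpi=300)
print(width_px, height_px)  # 12300 11700
```

Lowering dpi or the per-facet size shrinks the image quadratically, so it is probably the cheapest knob to try if the full resolution is not needed.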
import time
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    product = list(itertools.product(*itrs))
    return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}
ltt= ['little1']
methods=["m" + str(i) for i in range(1,21)]
#methods=['method 1', 'method 2', 'method 3', 'method 4']
#labels = ['label1','label2']
labels=["l" + str(i) for i in range(1,20)]
times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
#data['nw_score'] = np.random.sample(data.shape[0])
data['nw_score'] = np.random.choice([0,1],data.shape[0])
labels_fill = {0:'red',1:'blue'}
del methods
del labels
cmap=ListedColormap(['red', 'blue'])
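def facet(data, ax):
    # Reshape one facet's rows into a dtsi x rtsi matrix and draw it as an image
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    ax.imshow(data, cmap=cmap)

def myfacetgrid(data, row, col, figure=None):
    # unique() keeps first-appearance order, so the facets line up with the
    # titles taken from data.method.unique() and data.labels.unique() later;
    # np.unique would sort the strings and put 'm10' before 'm2'
    rows = data[row].unique()
    cols = data[col].unique()
    fig, axs = plt.subplots(len(rows), len(cols),
                            figsize=(2*len(cols)+1, 2*len(rows)+1))
    for i, r in enumerate(rows):
        row_data = data[data[row] == r]
        for j, c in enumerate(cols):
            this_data = row_data[row_data[col] == c]
            facet(this_data, axs[i,j])
    return fig, axs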
for lt in data.ltt.unique():
    with sns.plotting_context(font_scale=5.5):
        t = time.time()
        fig, axs = myfacetgrid(data[data.ltt==lt], row="labels", col="method")
        print(time.time()-t)
        for ax,method in zip(axs[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(axs[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        print(time.time()-t)
        fig.suptitle(lt, fontweight='bold', fontsize=12)
        fig.tight_layout()
        fig.subplots_adjust(top=0.8) # make some room for the title
        print(time.time()-t)
        fig.savefig(lt+'.png', dpi=300)
        print(time.time()-t)
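The timing here breaks down into about 6 seconds to create the facet grid, about 7 seconds to optimize the grid layout (via tight_layout; consider leaving it out!), and about 15 seconds to draw the figure.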
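Following the suggestion to leave out tight_layout, here is a minimal sketch of one possible variant. It is not the code from the answer: the helper names make_data and fast_grid are hypothetical, the margins passed to subplots_adjust are arbitrary guesses, the dataset is a small stand-in for the real one, and the non-interactive Agg backend is assumed so that the script runs without a display. It also splits the data with a single groupby pass instead of repeated boolean masks. No speed-up is measured here, so timings on real data would need to be checked.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen, no GUI needed
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def make_data(n_methods=3, n_labels=2, times=range(0, 20, 4), seed=0):
    """Small stand-in for the question's dataset (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    methods = ["m" + str(i) for i in range(1, n_methods + 1)]
    labels = ["l" + str(i) for i in range(1, n_labels + 1)]
    index = pd.MultiIndex.from_product(
        [labels, methods, times, times],
        names=["labels", "method", "dtsi", "rtsi"])
    df = index.to_frame(index=False)
    df["nw_score"] = rng.integers(0, 2, len(df))
    return df

def fast_grid(data, row, col, cmap):
    """Draw one imshow per (row, col) facet; fixed margins, no tight_layout."""
    rows = list(data[row].unique())  # first-appearance order
    cols = list(data[col].unique())
    fig, axs = plt.subplots(len(rows), len(cols), squeeze=False,
                            figsize=(2 * len(cols) + 1, 2 * len(rows) + 1))
    # One groupby pass yields every facet's rows
    for (r, c), sub in data.groupby([row, col], sort=False):
        ax = axs[rows.index(r), cols.index(c)]
        grid = sub.pivot(index="dtsi", columns="rtsi", values="nw_score")
        ax.imshow(grid.to_numpy(), cmap=cmap, vmin=0, vmax=1)
        ax.set_xticks([])
        ax.set_yticks([])
    # Fixed margins instead of fig.tight_layout(); values are arbitrary
    fig.subplots_adjust(left=0.05, right=0.98, bottom=0.02, top=0.9,
                        wspace=0.05, hspace=0.05)
    return fig, axs, rows, cols

cmap = ListedColormap(["red", "blue"])
fig, axs, rows, cols = fast_grid(make_data(), row="labels", col="method", cmap=cmap)
for ax, method in zip(axs[0, :], cols):
    ax.set_title(method, fontweight="bold", fontsize=12)
fig.savefig("little1_sketch.png", dpi=100)
```

Returning rows and cols from the helper keeps the titles tied to the order in which the facets were actually drawn, rather than relying on a second call to unique().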
Regarding python - how to speed up a seaborn heatmap, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/59380853/