python - 如何加速seaborn热图

标签 python matplotlib seaborn heatmap

正在努力解决searborn facetgrid 热图缓慢的问题。我已经从之前的 problem 扩展了我的数据集并感谢@Diziet Asahi 提供了facetgrid 问题的解决方案。

现在,我有 20x20 网格,每个网格中有 625 个点要映射。即使是一层 little1 也需要很长时间才能获得输出。我的真实数据有数千个层。

我的代码是这样的:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

ltt= ['little1']

methods=["m" + str(i) for i in range(1,21)]
labels=["l" + str(i) for i in range(1,20)]

times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
data['nw_score'] = np.random.choice([0,1],data.shape[0])

数据输出到:

Out[36]: 
            ltt method labels  dtsi  rtsi  nw_score
0       little1     m1     l1     0     0         1
1       little1     m1     l1     0     4         0
2       little1     m1     l1     0     8         0
3       little1     m1     l1     0    12         1
4       little1     m1     l1     0    16         0
        ...    ...    ...   ...   ...       ...
237495  little1    m20    l19    96    80         0
237496  little1    m20    l19    96    84         1
237497  little1    m20    l19    96    88         0
237498  little1    m20    l19    96    92         0
237499  little1    m20    l19    96    96         1

[237500 rows x 6 columns]

绘制并定义facet函数:

labels_fill = {0:'red',1:'blue'}

del methods
del labels

def facet(data,color):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    g = sns.heatmap(data, cmap=ListedColormap(['red', 'blue']), cbar=False,annot=True)

for lt in data.ltt.unique():
    with sns.plotting_context(font_scale=5.5):
        g = sns.FacetGrid(data[data.ltt==lt],row="labels", col="method", size=2, aspect=1,margin_titles=False)
        g = g.map_dataframe(facet)
        g.add_legend()
        g.set_titles(template="")

        for ax,method in zip(g.axes[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(g.axes[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        g.fig.suptitle(lt, fontweight='bold', fontsize=12)
        g.fig.tight_layout()
        g.fig.subplots_adjust(top=0.8) # make some room for the title

        g.savefig(lt+'.png', dpi=300)
    

enter image description here

我在一段时间后停止了代码,我们可以看到网格正在被一一填充,这非常耗时。生成此热图的速度慢得难以忍受。

我想知道是否有更好的方法来加快这个过程?

最佳答案

Seaborn 很慢。 如果您使用 matplotlib 而不是 seaborn,则每个图只需半分钟左右。这仍然很长,但考虑到您生成了约 12000x12000 像素的图形,这是可以预料的。

import time
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

ltt= ['little1']

methods=["m" + str(i) for i in range(1,21)]
#methods=['method 1', 'method 2', 'method 3', 'method 4']
#labels = ['label1','label2']
labels=["l" + str(i) for i in range(1,20)]

times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
#data['nw_score'] = np.random.sample(data.shape[0])
data['nw_score'] = np.random.choice([0,1],data.shape[0])

labels_fill = {0:'red',1:'blue'}

del methods
del labels


cmap=ListedColormap(['red', 'blue'])

def facet(data, ax):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    ax.imshow(data, cmap=cmap)

def myfacetgrid(data, row, col, figure=None):
    rows = np.unique(data[row].values)  
    cols = np.unique(data[col].values)

    fig, axs = plt.subplots(len(rows), len(cols), 
                            figsize=(2*len(cols)+1, 2*len(rows)+1))


    for i, r in enumerate(rows):
        row_data = data[data[row] == r]
        for j, c in enumerate(cols):
            this_data = row_data[row_data[col] == c]
            facet(this_data, axs[i,j])
    return fig, axs


for lt in data.ltt.unique():

    with sns.plotting_context(font_scale=5.5):
        t = time.time()
        fig, axs = myfacetgrid(data[data.ltt==lt], row="labels", col="method")
        print(time.time()-t)
        for ax,method in zip(axs[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(axs[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        print(time.time()-t)
        fig.suptitle(lt, fontweight='bold', fontsize=12)
        fig.tight_layout()
        fig.subplots_adjust(top=0.8) # make some room for the title
        print(time.time()-t)
        fig.savefig(lt+'.png', dpi=300)
        print(time.time()-t)

这里的时间分为约 6 秒创建facetgrid,约 7 秒优化网格布局(通过ight_layout - 考虑忽略它!),以及 15 秒绘制图形。

关于python - 如何加速seaborn热图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59380853/

相关文章:

python - 在 python 3.6 (anaconda3) 中执行 "import pyodbc"时,Jupyter Notebook 抛出错误 : ImportError: DLL load failed

python - 将 JSON 数据列表输出到 JSON 数组并将其写入文件

python - 删除 Matplotlib 中数据点周围的填充

python - 如何绘制给定参数曲线在某个点的法线

python - 应用多图动态轴共享

python - 如何在 Seaborn boxplot 中编辑 mustache 、传单、帽等的属性

python - 带有群图的 Seaborn PairGrid

python - 使用 scalapy 对运行 python 代码的 Scala 应用程序进行 docker 化时出现错误 'Unable to load python3'

python - 防止在 Sympy 中计算乘法表达式

python - 如何在 Seaborn relplot 中自定义标题和 y 标签