python - 使用 Python Plotly 自定义 Networkx 图(或散点图)

标签 python plot plotly networkx

我正在使用 Plotly 的 Python 界面生成网络。我设法用我想要的节点和边创建了一个网络,并控制了节点的大小。 我正在绝望地寻求有关如何执行以下操作的帮助:

  1. 添加节点标签
  2. 根据权重列表添加边缘标签
  3. 根据权重列表控制边缘线宽

所有这一切都没有使用“悬停”选项,因为它必须放在非交互式纸张中。我将不胜感激任何帮助! Plotly's output | In case this fails, the figure itself | matrix.csv 这是我的代码(大部分是从 Networkx 的 Plotly 教程中复制粘贴的):

import pandas as pd
import plotly.plotly as py
from plotly.graph_objs import *
import networkx as nx

matrix = pd.read_csv("matrix.csv", sep = "\t", index_col = 0, header = 0)

G = nx.DiGraph()

# add nodes:
G.add_nodes_from(matrix.columns)

# add edges:
edge_lst = [(i,j, matrix.loc[i,j])
            for i in matrix.index
            for j in matrix.columns
            if matrix.loc[i,j] != 0]
G.add_weighted_edges_from(edge_lst)

# create node trace:
node_trace = Scatter(x = [], y = [], text = [], mode = 'markers',
                    marker = Marker(
                    showscale = True,
                    colorscale = 'YIGnBu',
                    reversescale = True,
                    color = [],
                    size = [],
                    colorbar = dict(
                        thickness = 15,
                        title = 'Node Connections',
                        xanchor = 'left',
                        titleside = 'right'),
                    line = dict(width = 2)))

# set node positions
pos = nx.spring_layout(G)
for node in G.nodes():
    G.node[node]['pos']= pos[node]

for node in G.nodes():
    x, y = G.node[node]['pos']
    node_trace['x'].append(x)
    node_trace['y'].append(y)

# create edge trace:
edge_trace = Scatter(x = [], y = [], text = [],
                     line = Line(width = [], color = '#888'),
                     mode = 'lines')

for edge in G.edges():
    x0, y0 = G.node[edge[0]]['pos']
    x1, y1 = G.node[edge[1]]['pos']
    edge_trace['x'] += [x0, x1, None]
    edge_trace['y'] += [y0, y1, None]
    edge_trace['text'] += str(matrix.loc[edge[0], edge[1]])[:5]

# size nodes by degree
deg_dict = {deg[0]:int(deg[1]) for deg in list(G.degree())}
for node, degree in enumerate(deg_dict):
    node_trace['marker']['size'].append(deg_dict[degree] + 20)

fig = Figure(data = Data([edge_trace, node_trace]),
             layout = Layout(
                 title = '<br>AA Substitution Rates',
                 titlefont = dict(size = 16),
                 showlegend = True,
                 margin = dict(b = 20, l = 5, r = 5, t = 40),
                 annotations = [dict(
                     text = "sub title text",
                     showarrow = False,
                     xref = "paper", yref = "paper",
                     x = 0.005, y = -0.002)],
                 xaxis = XAxis(showgrid = False, 
                               zeroline = False, 
                               showticklabels = False),
                 yaxis = YAxis(showgrid = False, 
                               zeroline = False, 
                               showticklabels = False)))

py.plot(fig, filename = 'networkx')

最佳答案

所以

1. 这个问题的解决方案相对简单,您可以创建一个包含节点 ID 的列表,并将其设置在散点图的 text 属性中。然后将模式设置为“标记+文本”,就完成了。

2. 这有点棘手。您必须计算每行的中间位置并创建一个字典列表,包括该行的中间位置和粗细。然后添加 set 作为布局的注释

3. 这太复杂了,无法使用 plotly IMO 来完成。至于现在,我正在使用 networkx spring_layout 函数计算每个节点的位置。如果您想根据每条线的重量设置每条线的宽度,则必须使用一个函数修改位置,该函数考虑到每条线所附的所有标记。

奖励我为您提供了以不同方式为图表的每个组件着色的选项。

这是我刚才做的一个(稍微修改过的)函数,它执行 12:

import pandas as pd
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import networkx as nx

def scatter_plot_2d(G, folderPath, name, savePng = False):
    print("Creating scatter plot (2D)...")

    Nodes = [comp for comp in nx.connected_components(G)] # Looks for the graph's communities
    Edges = G.edges()
    edge_weights = nx.get_edge_attributes(G,'weight')

    labels = [] # names of the nodes to plot
    group = [] # id of the communities
    group_cnt = 0

    print("Communities | Number of Nodes")
    for subgroup in Nodes:
        group_cnt += 1
        print("      %d     |      %d" % (group_cnt, len(subgroup)))
        for node in subgroup:
            labels.append(int(node))
            group.append(group_cnt)

    labels, group = (list(t) for t in zip(*sorted(zip(labels, group))))

    layt = nx.spring_layout(G, dim=2) # Generates the layout of the graph
    Xn = [layt[k][0] for k in list(layt.keys())]  # x-coordinates of nodes
    Yn = [layt[k][1] for k in list(layt.keys())]  # y-coordinates
    Xe = []
    Ye = []

    plot_weights = []
    for e in Edges:
        Xe += [layt[e[0]][0], layt[e[1]][0], None]
        Ye += [layt[e[0]][1], layt[e[1]][1], None]
        ax = (layt[e[0]][0]+layt[e[1]][0])/2
        ay = (layt[e[0]][1]+layt[e[1]][1])/2
        plot_weights.append((edge_weights[(e[0], e[1])], ax, ay))

    annotations_list =[
                        dict(   
                            x=plot_weight[1],
                            y=plot_weight[2],
                            xref='x',
                            yref='y',
                            text=plot_weight[0],
                            showarrow=True,
                            arrowhead=7,
                            ax=plot_weight[1],
                            ay=plot_weight[2]
                        ) 
                        for plot_weight in plot_weights
                    ]

    trace1 = go.Scatter(  x=Xe,
                          y=Ye,
                          mode='lines',
                          line=dict(color='rgb(90, 90, 90)', width=1),
                          hoverinfo='none'
                        )

    trace2 = go.Scatter(  x=Xn,
                          y=Yn,
                          mode='markers+text',
                          name='Nodes',
                          marker=dict(symbol='circle',
                                      size=8,
                                      color=group,
                                      colorscale='Viridis',
                                      line=dict(color='rgb(255,255,255)', width=1)
                                      ),
                          text=labels,
                          textposition='top center',
                          hoverinfo='none'
                          )

    xaxis = dict(
                backgroundcolor="rgb(200, 200, 230)",
                gridcolor="rgb(255, 255, 255)",
                showbackground=True,
                zerolinecolor="rgb(255, 255, 255)"
                )
    yaxis = dict(
                backgroundcolor="rgb(230, 200,230)",
                gridcolor="rgb(255, 255, 255)",
                showbackground=True,
                zerolinecolor="rgb(255, 255, 255)"
                )

    layout = go.Layout(
        title=name,
        width=700,
        height=700,
        showlegend=False,
        plot_bgcolor="rgb(230, 230, 200)",
        scene=dict(
            xaxis=dict(xaxis),
            yaxis=dict(yaxis)
        ),
        margin=dict(
            t=100
        ),
        hovermode='closest',
        annotations=annotations_list
        , )
    data = [trace1, trace2]
    fig = go.Figure(data=data, layout=layout)
    plotDir = folderPath + "/"

    print("Plotting..")

    if savePng:
        plot(fig, filename=plotDir + name + ".html", auto_open=True, image = 'png', image_filename=plotDir + name,
         output_type='file', image_width=700, image_height=700, validate=False)
    else:
        plot(fig, filename=plotDir + name + ".html")

关于python - 使用 Python Plotly 自定义 Networkx 图(或散点图),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50078361/

相关文章:

python - 为什么我的子进程调用需要重新加载页面才能执行?

python - 按 25 个 block 对 CSV 中的行进行分组

r - 将垂直符号添加到绘图中

r - 如何将所有列的名称作为y轴并用ggplot将它们绘制在同一个图表中?

matlab - 如何在 Matlab 绘图中显示较浅边缘的深色边缘?

r - 如何关闭在 Plotly 中显示我的所有数据的工具提示?

用于散热的 Python numpy 矢量化

r - ggplotly : log argument cancels axis labels

python - 从 python 中的 colorlover 的色阶中获取单独的颜色

python - 无法从外部主机在 TCP 端口建立连接