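I'm practicing Python and want to build a graph in Python that has a start node and some requirements on the child nodes. Each node's value is a 3-digit number (e.g. 320, 110), and I want to generate the child nodes in the following order:
- subtract 1 from the first digit
- add 1 to the first digit
- subtract 1 from the second digit
- add 1 to the second digit
- subtract 1 from the third digit
- add 1 to the third digit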
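The start and goal nodes are read from a text file, which may have a third line containing a list of forbidden numbers, i.e. numbers the search algorithm is not allowed to visit.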
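As a minimal sketch of reading that input (assuming one number per line and a whitespace-separated third line; `parse_input` is a hypothetical name), a missing third line can be handled by relying on `readline()` returning an empty string at end of file:

```python
import io


def parse_input(f):
    """Read start, goal and the optional forbidden list from a file-like object.

    Assumes start and goal are on the first two lines and the forbidden
    numbers, if any, are whitespace-separated on the third line.
    """
    start = f.readline().strip()
    goal = f.readline().strip()
    # readline() returns "" at end of file, so a missing third line gives []
    forbidden = f.readline().split()
    return start, goal, forbidden


# in-memory "files" for illustration; with a real file use open("input.txt")
print(parse_input(io.StringIO("320\n110\n")))
print(parse_input(io.StringIO("320\n221\n330 420\n")))
```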
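Constraints:
- You cannot add to the digit 9 or subtract from the digit 0;
- You cannot make a move that transforms the current number into one of the forbidden numbers;
- You cannot change the same digit twice in two consecutive moves.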
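Note that since the numbers have 3 digits, there are at most 6 possible moves from the start node at the beginning. After the first move, the branching factor is at most 4 because of the movement constraints, in particular constraint 3.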
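The three constraints can be sketched as a successor function. This is a minimal sketch, not the asker's code: it assumes numbers are kept as 3-character strings and that the caller remembers which digit position the previous move changed (`successors` and `last_pos` are hypothetical names):

```python
def successors(digits, last_pos=None, forbidden=()):
    """Return (child, changed_position) pairs in the required order.

    digits:    the current value as a 3-character string, e.g. "320"
    last_pos:  index (0-2) of the digit changed by the previous move,
               or None for the start node
    forbidden: collection of 3-digit strings that may not be entered
    """
    result = []
    for pos in range(3):
        if pos == last_pos:
            continue  # constraint 3: never change the same digit twice in a row
        d = int(digits[pos])
        for delta in (-1, +1):  # subtract first, then add
            nd = d + delta
            if nd < 0 or nd > 9:
                continue  # constraint 1: no 0 - 1 and no 9 + 1
            child = digits[:pos] + str(nd) + digits[pos + 1:]
            if child in forbidden:
                continue  # constraint 2: forbidden numbers are never entered
            result.append((child, pos))
    return result


print([c for c, _ in successors("320")])
print([c for c, _ in successors("220", last_pos=0)])
```

For 320 this yields five children rather than six, because the third digit is 0 and cannot be decremented. After a move on one digit only the other two digits may change, which is where the limit of 4 comes from.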
I have already implemented a Node class for my graph, but I'm having trouble actually building the graph.
This is what I have done for the Node class:
class Node(object):
    def __init__(self, data):
        self.data = data
        self.children = []
        self.parent = []

    def add_child(self, obj):
        self.children.append(obj)

    def add_parent(self, obj):
        self.parent.append(obj)


root = Node(320)


def get_root():
    print(root.data)

# some things I've tried
# p = Node(root.data-100)
# p.add_parent(root.data)
# root.add_child(p.data)
# set_root(320)
get_root()
# print(root.data)
# print(root.children)
# print(p.parent)
# p = Node(root.data-100)
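In the attempts commented out above, `add_parent` and `add_child` receive `root.data` and `p.data` (plain integers), so the nodes never reference each other. A minimal sketch of linking the `Node` objects themselves, assuming the same class; `link` is a hypothetical helper:

```python
class Node(object):
    def __init__(self, data):
        self.data = data
        self.children = []
        self.parent = []

    def add_child(self, obj):
        self.children.append(obj)

    def add_parent(self, obj):
        self.parent.append(obj)


def link(parent, child):
    # store the Node objects, not their .data, so the graph can be walked later
    parent.add_child(child)
    child.add_parent(parent)


root = Node(320)
p = Node(root.data - 100)  # 220: first digit minus 1
link(root, p)
print([c.data for c in root.children])  # [220]
print([q.data for q in p.parent])       # [320]
```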
I have implemented a BFS that returns the correct path when it is given a graph, but I can't create the actual graph to use with this BFS. Here is my BFS:
def bfs(graph_to_search, start, end):
    # keep visited local so repeated calls start from a clean state
    visited = set()
    queue = [[start]]
    while queue:
        # Gets the first path in the queue
        path = queue.pop(0)
        # Gets the last node in the path
        vertex = path[-1]
        # Checks if we got to the end
        if vertex == end:
            return path
        # We check if the current node is already in the visited nodes set in order not to recheck it
        elif vertex not in visited:
            # enumerate all adjacent nodes, construct a new path and push it into the queue
            for current_neighbour in graph_to_search.get(vertex, []):
                new_path = list(path)
                new_path.append(current_neighbour)
                queue.append(new_path)
            # Mark the vertex as visited
            visited.add(vertex)
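Because constraint 3 makes the legal moves depend on which digit was changed last, one option is to skip building an explicit graph and generate neighbours on the fly inside the BFS. The sketch below assumes that approach: the search state is the pair (number, last changed position), and `successors` is a hypothetical helper, not part of the question's code:

```python
from collections import deque


def successors(digits, last_pos, forbidden):
    # same move rules as the question: subtract/add on each digit in order,
    # skipping 0 - 1, 9 + 1, forbidden numbers and the digit moved last time
    for pos in range(3):
        if pos == last_pos:
            continue
        for delta in (-1, 1):
            nd = int(digits[pos]) + delta
            if 0 <= nd <= 9:
                child = digits[:pos] + str(nd) + digits[pos + 1:]
                if child not in forbidden:
                    yield child, pos


def bfs(start, goal, forbidden=()):
    """Breadth-first search over a graph generated on the fly.

    A state is (number, position changed last), because constraint 3
    makes the allowed moves depend on how a number was reached.
    """
    queue = deque([(start, None, [start])])
    visited = {(start, None)}
    while queue:
        digits, last_pos, path = queue.popleft()
        if digits == goal:
            return path
        for child, pos in successors(digits, last_pos, forbidden):
            if (child, pos) not in visited:
                visited.add((child, pos))
                queue.append((child, pos, path + [child]))
    return None  # goal unreachable


print(bfs("320", "110"))
```

With start 320, goal 110 and no forbidden numbers it returns ['320', '220', '210', '110']: the first digit has to change twice, so constraint 3 forces a move on another digit in between. Using `deque.popleft()` instead of `list.pop(0)` keeps each dequeue O(1).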
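Example: with start node 320, end node 110 and no forbidden nodes, a BFS search on this graph would look like this:
Any help would be greatly appreciated. Thanks.
Best answer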
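First, you need to model the nodes and generate the graph. To do that, we first have to make some assumptions:
- it is an undirected graph
- the distances between nodes are equal or don't matter
- a Node needs some kind of identification number
- neighbour generation is relative to the current Node, so that functionality should be part of the Node instance
- if we don't specify a limit, the Graph could be generated indefinitely, so we have to introduce the concept of max_spread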
So the code for Node looks like this:
from copy import copy


def check_three_digits(value_name, digits):
    assert len(digits) == 3, "The {} should be of precise length 3. Actual: {}".format(value_name, digits)
    assert digits.isdigit(), "The {} should consist of 3 digits. Actual {}".format(value_name, digits)


class Node:
    _node_count = 0

    def __init__(self, data: str):
        check_three_digits("data param", data)
        self._id = Node._node_count
        self._data = data
        self._neighbours = []
        Node._node_count += 1

    @property
    def id(self):
        return self._id

    @property
    def data(self):
        return copy(self._data)

    @property
    def neighbours(self):
        return copy(self._neighbours)

    def add_neighbour(self, neighbour):
        self._neighbours.append(neighbour)

    def _new_neighbour(self, data):
        new_neighbour = Node(data)
        new_neighbour.add_neighbour(self)
        return new_neighbour

    def generate_neighbours(self, forbidden_nodes_digits=()):
        first_digit = self._data[0]
        second_digit = self._data[1]
        third_digit = self._data[2]
        first_digit_num = int(first_digit)
        second_digit_num = int(second_digit)
        third_digit_num = int(third_digit)
        sub_first_digit_num = first_digit_num - 1
        add_first_digit_num = first_digit_num + 1
        sub_second_digit_num = second_digit_num - 1
        add_second_digit_num = second_digit_num + 1
        sub_third_digit_num = third_digit_num - 1
        add_third_digit_num = third_digit_num + 1
        sub_first_digit_num = first_digit_num if sub_first_digit_num < 0 else sub_first_digit_num
        add_first_digit_num = first_digit_num if add_first_digit_num > 9 else add_first_digit_num
        sub_second_digit_num = second_digit_num if sub_second_digit_num < 0 else sub_second_digit_num
        add_second_digit_num = second_digit_num if add_second_digit_num > 9 else add_second_digit_num
        sub_third_digit_num = third_digit_num if sub_third_digit_num < 0 else sub_third_digit_num
        add_third_digit_num = third_digit_num if add_third_digit_num > 9 else add_third_digit_num
        for ndigits in [
            "{}{}{}".format(str(sub_first_digit_num), second_digit, third_digit),
            "{}{}{}".format(str(add_first_digit_num), second_digit, third_digit),
            "{}{}{}".format(first_digit, str(sub_second_digit_num), third_digit),
            "{}{}{}".format(first_digit, str(add_second_digit_num), third_digit),
            "{}{}{}".format(first_digit, second_digit, str(sub_third_digit_num)),
            "{}{}{}".format(first_digit, second_digit, str(add_third_digit_num)),
        ]:
            if ndigits in forbidden_nodes_digits:
                continue
            self._neighbours.append(self._new_neighbour(ndigits))

    def __repr__(self):
        return str(self)

    def __str__(self):
        return "Node({})".format(self._data)
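Note that `generate_neighbours` clamps an out-of-range digit back to its old value, so for 320 the "subtract 1 from the third digit" move produces another 320 node instead of being skipped, and constraint 3 is not enforced. A minimal sketch of a skipping variant as a standalone function (`neighbour_digits` is a hypothetical name); it still ignores constraint 3, which needs to know which digit was changed last:

```python
def neighbour_digits(data, forbidden_nodes_digits=()):
    """Return the valid neighbour digit strings of `data`, in the required order."""
    result = []
    for pos in range(3):
        digit = int(data[pos])
        for new_digit in (digit - 1, digit + 1):
            # skip the move entirely instead of clamping it back to `digit`
            if not 0 <= new_digit <= 9:
                continue
            candidate = data[:pos] + str(new_digit) + data[pos + 1:]
            if candidate not in forbidden_nodes_digits:
                result.append(candidate)
    return result


print(neighbour_digits("320", ["330", "420"]))  # ['220', '310', '321']
```

Inside `generate_neighbours`, looping over this list and calling `self._new_neighbour` for each entry would replace the six clamped candidates.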
To generate the graph, we have:
def generate_nodes(node, end_node_digits, forbidden_nodes_digits, visited_nodes=None, current_spread=0, max_spread=4):
    """
    Handles the generation of the graph.
    :node: the current node to generate neighbours for
    :end_node_digits: the digits at which to stop spreading further the graph from the current spread.
    :forbidden_nodes_digits: the digit sequences that must never be generated as nodes.
    :visited_nodes: Marks the nodes for which neighbours generation happened, to avoid repetition and infinite recursion.
    :current_spread: Marks the current level at which neighbours are being generated.
    :max_spread: Defines the maximum spread beyond which the graph no longer generates neighbours for nodes.
    """
    # initialize the kwargs with None values
    if visited_nodes is None:
        visited_nodes = []
    # mark the current node as visited
    visited_nodes.append(node.id)
    # no reason to generate further since we hit the max spread limit
    if current_spread >= max_spread:
        return
    # generate the neighbours for the current node
    node.generate_neighbours(forbidden_nodes_digits)
    # if we generated the end node, fall back, no need to generate further
    if end_node_digits in [n.data for n in node.neighbours]:
        return
    # make sure to generate neighbours for the current node's neighbours as well
    for neighbour in node.neighbours:
        if neighbour.id in visited_nodes:
            continue
        generate_nodes(
            neighbour, end_node_digits, forbidden_nodes_digits,
            visited_nodes=visited_nodes, current_spread=current_spread + 1, max_spread=max_spread
        )
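`generate_nodes` creates a fresh Node every time a number is reached, so one number can end up as several nodes (the start value 320, for instance, is created again as a neighbour of 220). If one node per number is preferred, a sketch that stores the graph as a dict keyed by the digit string and expands it level by level up to `max_spread` could look like this (`build_graph` is a hypothetical name; like the answer, it ignores constraint 3):

```python
def build_graph(start, forbidden=(), max_spread=2):
    """Undirected adjacency dict {digits: set of neighbour digits} up to max_spread moves."""
    def neighbours(data):
        for pos in range(3):
            digit = int(data[pos])
            for new_digit in (digit - 1, digit + 1):
                if 0 <= new_digit <= 9:
                    candidate = data[:pos] + str(new_digit) + data[pos + 1:]
                    if candidate not in forbidden:
                        yield candidate

    graph = {start: set()}
    frontier = [start]
    for _ in range(max_spread):
        next_frontier = []
        for data in frontier:
            for n in neighbours(data):
                if n not in graph:
                    # first time this number is seen: one dict entry per number
                    graph[n] = set()
                    next_frontier.append(n)
                # edges are stored in both directions (undirected graph)
                graph[data].add(n)
                graph[n].add(data)
        frontier = next_frontier
    return graph


g = build_graph("320", ["330", "420"], max_spread=1)
print(sorted(g["320"]))  # ['220', '310', '321']
print(len(g))            # 4
```

The result has the same shape as the `graph_to_search` dict that the question's BFS expects.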
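The search for this model looks like the following. Note that, despite its name, this recursive bfs follows one neighbour as deep as possible before trying the next one, so it actually performs a depth-first search and the path it returns is not guaranteed to be the shortest: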
def bfs(node, end_node_digits, visited_nodes=None, path=None):
    """
    Looks for a specific digit sequence in the graph starting from a specific node.
    :node: the node to start search from.
    :end_node_digits: The digit sequence to look for.
    :visited_nodes: The nodes for which BFS was already performed. Used to avoid infinite recursion and cyclic traversal.
    :path: The search path that led to this node.
    """
    # initialize the None kwargs
    if visited_nodes is None:
        visited_nodes = []
    if path is None:
        path = ""
    path += "({}, {}) ".format(node.id, node.data)
    # mark the current node as visited
    visited_nodes.append(node.id)
    # if we find the end node we can safely report back the match
    if node.data == end_node_digits:
        return path
    # if the current node doesn't match the end node then we look into the neighbours
    for neighbour in node.neighbours:
        # exclude the visited nodes (obviously excluding the node that generated these nodes)
        if neighbour.id in visited_nodes:
            continue
        # do a BFS in the subdivision of the graph
        result_path = bfs(neighbour, end_node_digits, visited_nodes, path)
        # if a match was found in the neighbour subdivision, report it back
        if result_path is not None:
            return result_path
    return None
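For comparison, a sketch of a breadth-first version over the same kind of objects. To stay self-contained it uses a minimal stand-in `Node` class (only `id`, `data` and `neighbours`) and a hypothetical `connect` helper instead of the answer's full class; the queue is a `collections.deque`:

```python
from collections import deque


class Node:
    """Minimal stand-in for the answer's Node: just id, data and neighbours."""
    _node_count = 0

    def __init__(self, data):
        self.id = Node._node_count
        Node._node_count += 1
        self.data = data
        self.neighbours = []


def connect(a, b):
    # undirected edge, as assumed in the answer
    a.neighbours.append(b)
    b.neighbours.append(a)


def bfs(node, end_node_digits):
    """Visit nodes level by level (FIFO queue), so the first match is a shortest path."""
    queue = deque([(node, [node])])
    visited = {node.id}
    while queue:
        current, path = queue.popleft()
        if current.data == end_node_digits:
            return " ".join("({}, {})".format(n.id, n.data) for n in path)
        for neighbour in current.neighbours:
            if neighbour.id not in visited:
                visited.add(neighbour.id)  # mark on enqueue to avoid duplicate entries
                queue.append((neighbour, path + [neighbour]))
    return None


# a long route 320-220-210-211-221 and a short route 320-321-221
n320, n220, n210, n211, n221, n321 = (Node(d) for d in ["320", "220", "210", "211", "221", "321"])
for a, b in [(n320, n220), (n220, n210), (n210, n211), (n211, n221), (n320, n321), (n321, n221)]:
    connect(a, b)
print(bfs(n320, "221"))  # (0, 320) (5, 321) (4, 221)
```

On this small graph the recursive search above would follow 220 first and return the five-node route through 210 and 211, while the queue-based version returns the two-move route through 321.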
We can illustrate what the code does by assuming an input.txt like the following:
320
221
330 420
and a __main__ block as follows:
if __name__ == '__main__':
    # retrieve the nodes from the input file
    start_node_digits = None
    end_node_digits = None
    forbidden_nodes_digits = []
    with open("input.txt", "r") as pf:
        start_node_digits = pf.readline().strip()
        end_node_digits = pf.readline().strip()
        forbidden_nodes_digits = pf.readline().split()
        forbidden_nodes_digits = [fnode.strip() for fnode in forbidden_nodes_digits]
    print("Start node digits: {}".format(start_node_digits))
    print("End node digits: {}".format(end_node_digits))
    print("Forbidden nodes digits: {}".format(forbidden_nodes_digits))
    # validate the input nodes data
    check_three_digits("start node", start_node_digits)
    check_three_digits("end node", end_node_digits)
    for fnode_digits in forbidden_nodes_digits:
        check_three_digits("forbidden node", fnode_digits)
    # create the first node and generate the graph
    first_node = Node(start_node_digits)
    print("Generate nodes for graph....")
    max_spread = 2
    generate_nodes(first_node, end_node_digits, forbidden_nodes_digits, max_spread=max_spread)
    # perform a BFS for a sequence of digits
    print("BFS for {}".format(end_node_digits))
    match_path = bfs(first_node, end_node_digits)
    print("BFS search result: {}".format(match_path))
We can also visualize the graph with these functions:
import networkx as nx
import matplotlib.pyplot as plt


def _draw_node(graph, node, visited_nodes=None):
    # initialize kwargs with None values
    if visited_nodes is None:
        visited_nodes = []
    # mark node as visited
    visited_nodes.append(node.id)
    for neighbour in node.neighbours:
        if neighbour.id in visited_nodes:
            continue
        graph.add_node(neighbour.id)
        graph.add_edge(node.id, neighbour.id)
        nx.set_node_attributes(graph, {neighbour.id: {'data': neighbour.data}})
        _draw_node(graph, neighbour, visited_nodes)


def draw_graph(first_node, start_node_digits, end_node_digits, forbidden_nodes_digits, fig_scale, fig_scale_exponent=1.2):
    g = nx.Graph()
    # add first node to the draw figure
    g.add_node(first_node.id)
    nx.set_node_attributes(g, {first_node.id: {'data': first_node.data}})
    _draw_node(g, first_node)
    # prepare graph drawing
    labels = nx.get_node_attributes(g, 'data')
    fig = plt.figure(frameon=False)
    INCH_FACTOR = 5  # inches
    fig_scale = fig_scale ** fig_scale_exponent
    fig.set_size_inches(fig_scale * INCH_FACTOR, fig_scale * INCH_FACTOR)
    nodes_attributes = nx.get_node_attributes(g, 'data')
    color_map = []
    for n in g:
        ndata = nodes_attributes[n]
        if ndata == start_node_digits:
            color_map.append('yellow')
        elif ndata == end_node_digits:
            color_map.append('cyan')
        elif ndata in forbidden_nodes_digits:
            # just in case something slips
            color_map.append('red')
        else:
            color_map.append("#e5e5e5")
    # actually draw the graph and save it to a PNG.
    nx.draw_networkx(
        g, with_labels=True, labels=labels, node_size=600,
        node_color=color_map,
        # node_color='#e5e5e5',
        font_weight='bold', font_size=10,
        pos=nx.drawing.nx_agraph.graphviz_layout(g)
    )
    plt.savefig("graph.png", dpi=100)
These can be called in the __main__ block, for example:
print("Draw graph...")
draw_graph(first_node, start_node_digits, end_node_digits, forbidden_nodes_digits, fig_scale=max_spread, fig_scale_exponent=1)
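The graph looks like this:
The BFS result is similar to: (0, 320) (1, 220) (10, 221)
Now, I'm not sure whether this fully matches the specification, but it should be a good starting point. There are also several ways to implement a graph; some people use lists of vertices and edges.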
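The vertex-and-edge-list representation can be converted into an adjacency dict in a few lines. A minimal sketch with made-up example data (`to_adjacency` is a hypothetical name), treating edges as undirected as the answer assumes:

```python
# made-up example data: vertices and undirected edges as plain lists
vertices = ["320", "220", "310", "321"]
edges = [("320", "220"), ("320", "310"), ("320", "321")]


def to_adjacency(vertices, edges):
    """Convert vertex/edge lists into an undirected adjacency dict."""
    adjacency = {v: [] for v in vertices}
    for a, b in edges:
        # each edge is recorded on both endpoints
        adjacency[a].append(b)
        adjacency[b].append(a)
    return adjacency


graph = to_adjacency(vertices, edges)
print(graph["320"])  # ['220', '310', '321']
print(graph["220"])  # ['320']
```

A dict like this can be passed as the `graph_to_search` argument of the question's BFS.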
For the graphviz layout in networkx, you need to install the pygraphviz package via pip; if you are on Linux, you may also need to run sudo apt-get install graphviz libgraphviz-dev pkg-config
Regarding "python - How do I build this graph in Python?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/55505389/