From 496a69cb4f8ec8cfe5a6d6f4d82b9db14215682a Mon Sep 17 00:00:00 2001 From: Bruno Meneguello <1322552+bkmeneguello@users.noreply.github.com> Date: Mon, 28 Dec 2020 16:52:17 -0300 Subject: [PATCH] Allow node to cluster edges --- diagrams/__init__.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 556e043..4b19d9b 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -115,7 +115,9 @@ class Diagram: self.dot = Digraph(self.name, filename=self.filename) self._nodes = {} + self._edges = {} + self.dot.attr(compound="true") # Set attributes. for k, v in self._default_graph_attrs.items(): self.dot.graph_attr[k] = v @@ -155,6 +157,17 @@ class Diagram: for nodeid, node in self._nodes.items(): self.dot.node(nodeid, label=node['label'], **node['attrs']) + for nodes, edge in self._edges.items(): + node1, node2 = nodes + nodeid1, nodeid2 = node1.nodeid, node2.nodeid + if hasattr(node1, '_nodes') and node1._nodes: + edge._attrs['ltail'] = nodeid1 + nodeid1 = next(iter(node1._nodes.keys())) + if hasattr(node2, '_nodes') and node2._nodes: + edge._attrs['lhead'] = nodeid2 + nodeid2 = next(iter(node2._nodes.keys())) + self.dot.edge(nodeid1, nodeid2, **edge.attrs) + self.render() # Remove the graphviz file leaving only the image. os.remove(self.filename) @@ -193,7 +206,7 @@ class Diagram: def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: """Connect the two Nodes.""" - self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs) + self._edges[(node, node2)] = edge def subgraph(self, dot: Digraph) -> None: """Create a subgraph for clustering""" @@ -387,16 +400,16 @@ class Node: self._diagram.node(self._id, self.label, **self._attrs) def __enter__(self): - setcluster(self) - self.name = "cluster_" + self.label - self.dot = Digraph(self.name) - self._nodes = {} - if self._cluster: self._cluster.remove_node(self._id) else: self._diagram.remove_node(self._id) + setcluster(self) + self._id = "cluster_" + self.label + self.dot = Digraph(self._id) + self._nodes = {} + # Set attributes. for k, v in self._default_graph_attrs.items(): self.dot.graph_attr[k] = v @@ -421,7 +434,7 @@ class Node: def __exit__(self, exc_type, exc_value, traceback): for nodeid, node in self._nodes.items(): self.dot.node(nodeid, label=node['label'], **node['attrs']) - + if self._cluster: self._cluster.subgraph(self.dot) else: