diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 2d16ca6..690a8a0 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -83,6 +83,7 @@ class Diagram: direction: str = "LR", curvestyle: str = "ortho", outformat: str = "png", + autolabel: bool = False, show: bool = True, graph_attr: dict = {}, node_attr: dict = {}, @@ -142,6 +143,7 @@ class Diagram: self.dot.edge_attr.update(edge_attr) self.show = show + self.autolabel = autolabel def __str__(self) -> str: return str(self.dot) @@ -292,11 +294,23 @@ class Node: self._id = nodeid or self._rand_id() self.label = label + # Node must be belong to a diagrams. + self._diagram = getdiagram() + if self._diagram is None: + raise EnvironmentError("Global diagrams context not set up") + + if self._diagram.autolabel: + prefix = self.__class__.__name__ + if self.label: + self.label = prefix + "\n" + self.label + else: + self.label = prefix + # fmt: off # If a node has an icon, increase the height slightly to avoid # that label being spanned between icon image and white space. # Increase the height by the number of new lines included in the label. - padding = 0.4 * (label.count('\n')) + padding = 0.4 * (self.label.count('\n')) self._attrs = { "shape": "none", "height": str(self._height + padding), @@ -306,10 +320,6 @@ class Node: # fmt: on self._attrs.update(attrs) - # Node must be belong to a diagrams. - self._diagram = getdiagram() - if self._diagram is None: - raise EnvironmentError("Global diagrams context not set up") self._cluster = getcluster() # If a node is in the cluster context, add it to cluster. diff --git a/tests/test_diagram.py b/tests/test_diagram.py index dc0b602..00bdacc 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -107,6 +107,12 @@ class DiagramTest(unittest.TestCase): with Diagram(show=False): Node("node1") self.assertTrue(os.path.exists(f"{self.name}.png")) + + def test_autolabel(self): + with Diagram(name=os.path.join(self.name, "nodes_to_node"), show=False): + node1 = Node("node1") + self.assertTrue(node1.label,"Node\nnode1") + def test_outformat_list(self): """Check that outformat render all the files from the list."""