diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 6e282d9..fd34a0f 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -1,8 +1,9 @@ import contextvars +import html import os import uuid from pathlib import Path -from typing import List, Union, Dict +from typing import List, Union, Dict, Sequence from graphviz import Digraph @@ -36,9 +37,77 @@ def getcluster() -> "Cluster": def setcluster(cluster: "Cluster"): __cluster.set(cluster) +def new_init(cls, init): + def reset_init(*args, **kwargs): + cls.__init__ = init + return reset_init -class Diagram: +class _Cluster: __directions = ("TB", "BT", "LR", "RL") + + def __init__(self, name=None, **kwargs): + self.dot = Digraph(name, **kwargs) + self.depth = 0 + self.nodes = {} + self.subgraphs = [] + + try: + self._parent = getcluster() or getdiagram() + except LookupError: + self._parent = None + + + def __enter__(self): + setcluster(self) + return self + + def __exit__(self, *args): + setcluster(self._parent) + + if not (self.nodes or self.subgraphs): + return + + for node in self.nodes.values(): + self.dot.node(node.nodeid, label=node.label, **node._attrs) + + for subgraph in self.subgraphs: + self.dot.subgraph(subgraph.dot) + + if self._parent: + self._parent.remove_node(self.nodeid) + self._parent.subgraph(self) + + def node(self, node: "Node") -> None: + """Create a new node.""" + self.nodes[node.nodeid] = node + + def remove_node(self, nodeid: str) -> None: + del self.nodes[nodeid] + + def subgraph(self, subgraph: "_Cluster") -> None: + """Create a subgraph for clustering""" + self.subgraphs.append(subgraph) + + @property + def nodes_iter(self): + if self.nodes: + yield from self.nodes.values() + if self.subgraphs: + for subgraph in self.subgraphs: + yield from subgraph.nodes_iter + + def _validate_direction(self, direction: str): + direction = direction.upper() + for v in self.__directions: + if v == direction: + return True + return False + + def __str__(self) -> str: + return str(self.dot) + + +class Diagram(_Cluster): __curvestyles = ("ortho", "curved") __outformats = ("png", "jpg", "svg", "pdf", "dot") @@ -105,15 +174,20 @@ class Diagram: :param edge_attr: Provide edge_attr dot config attributes. :param strict: Rendering should merge multi-edges. """ + self.name = name if not name and not filename: filename = "diagrams_image" elif not filename: filename = "_".join(self.name.split()).lower() self.filename = filename + + super().__init__(self.name, filename=self.filename) + self.edges = {} self.dot = Digraph(self.name, filename=self.filename, strict=strict) # Set attributes. + self.dot.attr(compound="true") for k, v in self._default_graph_attrs.items(): self.dot.graph_attr[k] = v self.dot.graph_attr["label"] = self.name @@ -147,18 +221,29 @@ class Diagram: self.show = show self.autolabel = autolabel - def __str__(self) -> str: - return str(self.dot) - def __enter__(self): setdiagram(self) + super().__enter__() return self + + def __exit__(self, *args): + super().__exit__(*args) + setdiagram(None) + + for (node1, node2), edge in self.edges.items(): + cluster_node1 = next(node1.nodes_iter, None) + if cluster_node1: + edge._attrs['ltail'] = node1.nodeid + node1 = cluster_node1 + cluster_node2 = next(node2.nodes_iter, None) + if cluster_node2: + edge._attrs['lhead'] = node2.nodeid + node2 = cluster_node2 + self.dot.edge(node1.nodeid, node2.nodeid, **edge.attrs) - def __exit__(self, exc_type, exc_value, traceback): self.render() # Remove the graphviz file leaving only the image. os.remove(self.filename) - setdiagram(None) def _repr_png_(self): return self.dot.pipe(format="png") @@ -172,17 +257,9 @@ class Diagram: def _validate_outformat(self, outformat: str) -> bool: return outformat.lower() in self.__outformats - def node(self, nodeid: str, label: str, **attrs) -> None: - """Create a new node.""" - self.dot.node(nodeid, label=label, **attrs) - def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: """Connect the two Nodes.""" - self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs) - - def subgraph(self, dot: Digraph) -> None: - """Create a subgraph for clustering""" - self.dot.subgraph(dot) + self.edges[(node, node2)] = edge def render(self) -> None: if isinstance(self.outformat, list): @@ -192,8 +269,8 @@ class Diagram: self.dot.render(format=self.outformat, view=self.show, quiet=True) -class Cluster: - __directions = ("TB", "BT", "LR", "RL") +class Node(_Cluster): + """Node represents a node for a specific backend service.""" __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3") # fmt: off @@ -281,14 +358,56 @@ class Node: _icon_dir = None _icon = None - + _icon_size = 30 + _direction = "LR" _height = 1.9 - def __init__(self, label: str = "", *, nodeid: str = None, **attrs: Dict): + # fmt: on + + def __new__(cls, *args, **kwargs): + instance = object.__new__(cls) + lazy = kwargs.pop('_no_init', False) + if not lazy: + return instance + cls.__init__ = new_init(cls, cls.__init__) + return instance + + def __init__( + self, + label: str = "", + direction: str = None, + icon: object = None, + icon_size: int = None, + **attrs: Dict + ): """Node represents a system component. :param label: Node label. + :param direction: Data flow direction. Default is "LR" (left to right). + :param icon: Custom icon for tihs cluster. Must be a node class or reference. + :param icon_size: The icon size when used as a Cluster. Default is 30. """ + # Generates an ID for identifying a node. + self._id = self._rand_id() + if isinstance(label, str): + self.label = label + elif isinstance(label, Sequence): + self.label = "\n".join(label) + else: + self.label = str(label) + + super().__init__() + + if direction: + if not self._validate_direction(direction): + raise ValueError(f'"{direction}" is not a valid direction') + self._direction = direction + if icon: + _node = icon(_no_init=True) + self._icon = _node._icon + self._icon_dir = _node._icon_dir + if icon_size: + self._icon_size = icon_size # Generates an ID for identifying a node, unless specified self._id = nodeid or self._rand_id() self.label = label @@ -310,11 +429,14 @@ class Node: # that label being spanned between icon image and white space. # Increase the height by the number of new lines included in the label. padding = 0.4 * (self.label.count('\n')) + icon_path = self._load_icon() self._attrs = { "shape": "none", "height": str(self._height + padding), - "image": self._load_icon(), - } if self._icon else {} + "image": icon_path, + } if icon_path else {} + + self._attrs['tooltip'] = (icon if icon else self).__class__.__name__ # fmt: on self._attrs.update(attrs) @@ -322,10 +444,43 @@ class Node: self._cluster = getcluster() # If a node is in the cluster context, add it to cluster. - if self._cluster: - self._cluster.node(self._id, self.label, **self._attrs) + if not self._parent: + raise EnvironmentError("Global diagrams context not set up") + self._parent.node(self) + + def __enter__(self): + super().__enter__() + + # Set attributes. + for k, v in self._default_graph_attrs.items(): + self.dot.graph_attr[k] = v + for k, v in self._attrs.items(): + self.dot.graph_attr[k] = v + + icon = self._load_icon() + if icon: + lines = iter(html.escape(self.label).split("\n")) + self.dot.graph_attr["label"] = '<' +\ + f'' +\ + f'' +\ + ''.join(f'' for line in lines) +\ + '
{next(lines)}
{line}
>' else: - self._diagram.node(self._id, self.label, **self._attrs) + self.dot.graph_attr["label"] = self.label + + self.dot.graph_attr["rankdir"] = self._direction + + # Set cluster depth for distinguishing the background color + self.depth = self._parent.depth + 1 + coloridx = self.depth % len(self.__bgcolors) + self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx] + + return self + + def __exit__(self, *args): + super().__exit__(*args) + self._id = "cluster_" + self.nodeid + self.dot.name = self.nodeid def __repr__(self): _name = self.__class__.__name__ @@ -400,7 +555,7 @@ class Node: @property def nodeid(self): return self._id - + # TODO: option for adding flow description to the connection edge def connect(self, node: "Node", edge: "Edge"): """Connect to other node. @@ -414,7 +569,7 @@ class Node: if not isinstance(edge, Edge): ValueError(f"{edge} is not a valid Edge") # An edge must be added on the global diagrams, not a cluster. - self._diagram.connect(self, node, edge) + getdiagram().connect(self, node, edge) return node @staticmethod @@ -422,8 +577,10 @@ class Node: return uuid.uuid4().hex def _load_icon(self): - basedir = Path(os.path.abspath(os.path.dirname(__file__))) - return os.path.join(basedir.parent, self._icon_dir, self._icon) + if self._icon and self._icon_dir: + basedir = Path(os.path.abspath(os.path.dirname(__file__))) + return os.path.join(basedir.parent, self._icon_dir, self._icon) + return None class Edge: @@ -472,6 +629,7 @@ class Edge: # Graphviz complaining about using label for edges, so replace it with xlabel. # Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83 self._attrs["label"] = label + self._attrs["tooltip"] = label if color: self._attrs["color"] = color if style: @@ -544,4 +702,4 @@ class Edge: return {**self._attrs, "dir": direction} -Group = Cluster +Group = Cluster = Node diff --git a/diagrams/aws/__init__.py b/diagrams/aws/__init__.py index 1550a0d..8c912ba 100644 --- a/diagrams/aws/__init__.py +++ b/diagrams/aws/__init__.py @@ -2,7 +2,7 @@ AWS provides a set of services for Amazon Web Service provider. """ -from diagrams import Node +from diagrams import Node, Cluster class _AWS(Node): diff --git a/diagrams/aws/cluster.py b/diagrams/aws/cluster.py new file mode 100644 index 0000000..6ecbabc --- /dev/null +++ b/diagrams/aws/cluster.py @@ -0,0 +1,104 @@ +from diagrams import Cluster +from diagrams.aws.compute import EC2, ApplicationAutoScaling +from diagrams.aws.network import VPC, PrivateSubnet, PublicSubnet + +class Region(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dotted", + "labeljust": "l", + "pencolor": "#AEB6BE", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + +class AvailabilityZone(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dashed", + "labeljust": "l", + "pencolor": "#27a0ff", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + +class VirtualPrivateCloud(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#00D110", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = VPC + +class PrivateSubnet(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#329CFF", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = PrivateSubnet + +class PublicSubnet(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#00D110", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = PublicSubnet + +class SecurityGroup(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dashed", + "labeljust": "l", + "pencolor": "#FF361E", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + +class AutoScalling(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dashed", + "labeljust": "l", + "pencolor": "#FF7D1E", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = ApplicationAutoScaling + +class EC2Contents(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#FFB432", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = EC2 diff --git a/diagrams/azure/cluster.py b/diagrams/azure/cluster.py new file mode 100644 index 0000000..73bfcda --- /dev/null +++ b/diagrams/azure/cluster.py @@ -0,0 +1,143 @@ +from diagrams import Cluster +from diagrams.azure.compute import VM, VMWindows, VMLinux #, VMScaleSet # Depends on PR-404 +from diagrams.azure.network import VirtualNetworks, Subnets, NetworkSecurityGroupsClassic + +class Subscription(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dotted", + "labeljust": "l", + "pencolor": "#AEB6BE", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + +class Region(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dotted", + "labeljust": "l", + "pencolor": "#AEB6BE", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + +class AvailabilityZone(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dashed", + "labeljust": "l", + "pencolor": "#27a0ff", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + +class VirtualNetwork(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#00D110", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = VirtualNetworks + +class SubnetWithNSG(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#329CFF", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = NetworkSecurityGroupsClassic + +class Subnet(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#00D110", + "fontname": "sans-serif", + "fontsize": "12", + } + # fmt: on + _icon = Subnets + +class SecurityGroup(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "dashed", + "labeljust": "l", + "pencolor": "#FF361E", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + +class VMContents(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#FFB432", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = VM + +class VMLinuxContents(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#FFB432", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = VMLinux + +class VMWindowsContents(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "", + "labeljust": "l", + "pencolor": "#FFB432", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = VMWindows + +# Depends on PR-404 +# class VMSS(Cluster): +# # fmt: off +# _default_graph_attrs = { +# "shape": "box", +# "style": "dashed", +# "labeljust": "l", +# "pencolor": "#FF7D1E", +# "fontname": "Sans-Serif", +# "fontsize": "12", +# } +# # fmt: on +# _icon = VMScaleSet diff --git a/diagrams/onprem/cluster.py b/diagrams/onprem/cluster.py new file mode 100644 index 0000000..4fd62f0 --- /dev/null +++ b/diagrams/onprem/cluster.py @@ -0,0 +1,15 @@ +from diagrams import Cluster +from diagrams.onprem.compute import Server + +class ServerContents(Cluster): + # fmt: off + _default_graph_attrs = { + "shape": "box", + "style": "rounded,dotted", + "labeljust": "l", + "pencolor": "#A0A0A0", + "fontname": "Sans-Serif", + "fontsize": "12", + } + # fmt: on + _icon = Server diff --git a/docs/guides/cluster.md b/docs/guides/cluster.md index f4b55b2..056b6c9 100644 --- a/docs/guides/cluster.md +++ b/docs/guides/cluster.md @@ -66,6 +66,36 @@ with Diagram("Event Processing", show=False): handlers >> dw ``` +## Clusters with icons in the label + +You can add a Node icon before the cluster label (and specify its size as well). You need to import the used Node class first. + +It's also possible to use the node in the `with` context adding `cluster=True` to +make it behave like a cluster. + +```python +from diagrams import Cluster, Diagram +from diagrams.aws.compute import ECS +from diagrams.aws.database import RDS, Aurora +from diagrams.aws.network import Route53, VPC + +with Diagram("Simple Web Service with DB Cluster", show=False): + dns = Route53("dns") + web = ECS("service") + + with Cluster(label='VPC',icon=VPC): + with Cluster("DB Cluster",icon=Aurora,icon_size=30): + db_master = RDS("master") + db_master - [RDS("slave1"), + RDS("slave2")] + with Aurora("DB Cluster", cluster=True): + db_master = RDS("master") + db_master - [RDS("slave1"), + RDS("slave2")] + + dns >> web >> db_master +``` + ![event processing diagram](/img/event_processing_diagram.png) > There is no depth limit of nesting. Feel free to create nested clusters as deep as you want. diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/aws.png b/examples/aws.png new file mode 100644 index 0000000..f762f76 Binary files /dev/null and b/examples/aws.png differ diff --git a/examples/aws.py b/examples/aws.py new file mode 100644 index 0000000..3599a5c --- /dev/null +++ b/examples/aws.py @@ -0,0 +1,26 @@ +from diagrams import Diagram, Edge +from diagrams.aws.cluster import * +from diagrams.aws.compute import EC2 +from diagrams.onprem.container import Docker +from diagrams.onprem.cluster import * +from diagrams.aws.network import ELB + +with Diagram(name="", direction="TB", show=True): + with Cluster("AWS"): + with Region("eu-west-1"): + with AvailabilityZone("eu-west-1a"): + with VirtualPrivateCloud(""): + with PrivateSubnet("Private"): + with SecurityGroup("web sg"): + with AutoScalling(""): + with EC2Contents("A"): + d1 = Docker("Container") + with ServerContents("A1"): + d2 = Docker("Container") + + with PublicSubnet("Public"): + with SecurityGroup("elb sg"): + lb = ELB() + + lb >> Edge(forward=True, reverse=True) >> d1 + lb >> Edge(forward=True, reverse=True) >> d2 diff --git a/examples/azure.png b/examples/azure.png new file mode 100644 index 0000000..04af158 Binary files /dev/null and b/examples/azure.png differ diff --git a/examples/azure.py b/examples/azure.py new file mode 100644 index 0000000..e094887 --- /dev/null +++ b/examples/azure.py @@ -0,0 +1,24 @@ +from diagrams import Diagram, Edge +from diagrams.azure.cluster import * +from diagrams.azure.compute import VM +from diagrams.onprem.container import Docker +from diagrams.onprem.cluster import * +from diagrams.azure.network import LoadBalancers + +with Diagram(name="", filename="azure", direction="TB", show=True): + with Cluster("Azure"): + with Region("East US2"): + with AvailabilityZone("Zone 2"): + with VirtualNetwork(""): + with SubnetWithNSG("Private"): + # with VMScaleSet(""): # Depends on PR-404 + with VMContents("A"): + d1 = Docker("Container") + with ServerContents("A1"): + d2 = Docker("Container") + + with Subnet("Public"): + lb = LoadBalancers() + + lb >> Edge(forward=True, reverse=True) >> d1 + lb >> Edge(forward=True, reverse=True) >> d2 diff --git a/tests/test_diagram.py b/tests/test_diagram.py index 00bdacc..28369fe 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -154,20 +154,20 @@ class ClusterTest(unittest.TestCase): def test_with_global_context(self): with Diagram(name=os.path.join(self.name, "with_global_context"), show=False): - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) with Cluster(): - self.assertIsNotNone(getcluster()) - self.assertIsNone(getcluster()) + self.assertNotEqual(getcluster(), getdiagram()) + self.assertEqual(getcluster(), getdiagram()) def test_with_nested_cluster(self): with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False): - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) with Cluster() as c1: self.assertEqual(c1, getcluster()) with Cluster() as c2: self.assertEqual(c2, getcluster()) self.assertEqual(c1, getcluster()) - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) def test_node_not_in_diagram(self): # Node must be belong to a diagrams.