diff --git a/diagrams/__init__.py b/diagrams/__init__.py
index 6e282d9..fd34a0f 100644
--- a/diagrams/__init__.py
+++ b/diagrams/__init__.py
@@ -1,8 +1,9 @@
import contextvars
+import html
import os
import uuid
from pathlib import Path
-from typing import List, Union, Dict
+from typing import List, Union, Dict, Sequence
from graphviz import Digraph
@@ -36,9 +37,77 @@ def getcluster() -> "Cluster":
def setcluster(cluster: "Cluster"):
__cluster.set(cluster)
+def new_init(cls, init):
+ def reset_init(*args, **kwargs):
+ cls.__init__ = init
+ return reset_init
-class Diagram:
+class _Cluster:
__directions = ("TB", "BT", "LR", "RL")
+
+ def __init__(self, name=None, **kwargs):
+ self.dot = Digraph(name, **kwargs)
+ self.depth = 0
+ self.nodes = {}
+ self.subgraphs = []
+
+ try:
+ self._parent = getcluster() or getdiagram()
+ except LookupError:
+ self._parent = None
+
+
+ def __enter__(self):
+ setcluster(self)
+ return self
+
+ def __exit__(self, *args):
+ setcluster(self._parent)
+
+ if not (self.nodes or self.subgraphs):
+ return
+
+ for node in self.nodes.values():
+ self.dot.node(node.nodeid, label=node.label, **node._attrs)
+
+ for subgraph in self.subgraphs:
+ self.dot.subgraph(subgraph.dot)
+
+ if self._parent:
+ self._parent.remove_node(self.nodeid)
+ self._parent.subgraph(self)
+
+ def node(self, node: "Node") -> None:
+ """Create a new node."""
+ self.nodes[node.nodeid] = node
+
+ def remove_node(self, nodeid: str) -> None:
+ del self.nodes[nodeid]
+
+ def subgraph(self, subgraph: "_Cluster") -> None:
+ """Create a subgraph for clustering"""
+ self.subgraphs.append(subgraph)
+
+ @property
+ def nodes_iter(self):
+ if self.nodes:
+ yield from self.nodes.values()
+ if self.subgraphs:
+ for subgraph in self.subgraphs:
+ yield from subgraph.nodes_iter
+
+ def _validate_direction(self, direction: str):
+ direction = direction.upper()
+ for v in self.__directions:
+ if v == direction:
+ return True
+ return False
+
+ def __str__(self) -> str:
+ return str(self.dot)
+
+
+class Diagram(_Cluster):
__curvestyles = ("ortho", "curved")
__outformats = ("png", "jpg", "svg", "pdf", "dot")
@@ -105,15 +174,20 @@ class Diagram:
:param edge_attr: Provide edge_attr dot config attributes.
:param strict: Rendering should merge multi-edges.
"""
+
self.name = name
if not name and not filename:
filename = "diagrams_image"
elif not filename:
filename = "_".join(self.name.split()).lower()
self.filename = filename
+
+ super().__init__(self.name, filename=self.filename)
+ self.edges = {}
self.dot = Digraph(self.name, filename=self.filename, strict=strict)
# Set attributes.
+ self.dot.attr(compound="true")
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.name
@@ -147,18 +221,29 @@ class Diagram:
self.show = show
self.autolabel = autolabel
- def __str__(self) -> str:
- return str(self.dot)
-
def __enter__(self):
setdiagram(self)
+ super().__enter__()
return self
+
+ def __exit__(self, *args):
+ super().__exit__(*args)
+ setdiagram(None)
+
+ for (node1, node2), edge in self.edges.items():
+ cluster_node1 = next(node1.nodes_iter, None)
+ if cluster_node1:
+ edge._attrs['ltail'] = node1.nodeid
+ node1 = cluster_node1
+ cluster_node2 = next(node2.nodes_iter, None)
+ if cluster_node2:
+ edge._attrs['lhead'] = node2.nodeid
+ node2 = cluster_node2
+ self.dot.edge(node1.nodeid, node2.nodeid, **edge.attrs)
- def __exit__(self, exc_type, exc_value, traceback):
self.render()
# Remove the graphviz file leaving only the image.
os.remove(self.filename)
- setdiagram(None)
def _repr_png_(self):
return self.dot.pipe(format="png")
@@ -172,17 +257,9 @@ class Diagram:
def _validate_outformat(self, outformat: str) -> bool:
return outformat.lower() in self.__outformats
- def node(self, nodeid: str, label: str, **attrs) -> None:
- """Create a new node."""
- self.dot.node(nodeid, label=label, **attrs)
-
def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
"""Connect the two Nodes."""
- self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs)
-
- def subgraph(self, dot: Digraph) -> None:
- """Create a subgraph for clustering"""
- self.dot.subgraph(dot)
+ self.edges[(node, node2)] = edge
def render(self) -> None:
if isinstance(self.outformat, list):
@@ -192,8 +269,8 @@ class Diagram:
self.dot.render(format=self.outformat, view=self.show, quiet=True)
-class Cluster:
- __directions = ("TB", "BT", "LR", "RL")
+class Node(_Cluster):
+ """Node represents a node for a specific backend service."""
__bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")
# fmt: off
@@ -281,14 +358,56 @@ class Node:
_icon_dir = None
_icon = None
-
+ _icon_size = 30
+ _direction = "LR"
_height = 1.9
- def __init__(self, label: str = "", *, nodeid: str = None, **attrs: Dict):
+ # fmt: on
+
+ def __new__(cls, *args, **kwargs):
+ instance = object.__new__(cls)
+ lazy = kwargs.pop('_no_init', False)
+ if not lazy:
+ return instance
+ cls.__init__ = new_init(cls, cls.__init__)
+ return instance
+
+ def __init__(
+ self,
+ label: str = "",
+ direction: str = None,
+ icon: object = None,
+ icon_size: int = None,
+ **attrs: Dict
+ ):
"""Node represents a system component.
:param label: Node label.
+ :param direction: Data flow direction. Default is "LR" (left to right).
+ :param icon: Custom icon for tihs cluster. Must be a node class or reference.
+ :param icon_size: The icon size when used as a Cluster. Default is 30.
"""
+ # Generates an ID for identifying a node.
+ self._id = self._rand_id()
+ if isinstance(label, str):
+ self.label = label
+ elif isinstance(label, Sequence):
+ self.label = "\n".join(label)
+ else:
+ self.label = str(label)
+
+ super().__init__()
+
+ if direction:
+ if not self._validate_direction(direction):
+ raise ValueError(f'"{direction}" is not a valid direction')
+ self._direction = direction
+ if icon:
+ _node = icon(_no_init=True)
+ self._icon = _node._icon
+ self._icon_dir = _node._icon_dir
+ if icon_size:
+ self._icon_size = icon_size
# Generates an ID for identifying a node, unless specified
self._id = nodeid or self._rand_id()
self.label = label
@@ -310,11 +429,14 @@ class Node:
# that label being spanned between icon image and white space.
# Increase the height by the number of new lines included in the label.
padding = 0.4 * (self.label.count('\n'))
+ icon_path = self._load_icon()
self._attrs = {
"shape": "none",
"height": str(self._height + padding),
- "image": self._load_icon(),
- } if self._icon else {}
+ "image": icon_path,
+ } if icon_path else {}
+
+ self._attrs['tooltip'] = (icon if icon else self).__class__.__name__
# fmt: on
self._attrs.update(attrs)
@@ -322,10 +444,43 @@ class Node:
self._cluster = getcluster()
# If a node is in the cluster context, add it to cluster.
- if self._cluster:
- self._cluster.node(self._id, self.label, **self._attrs)
+ if not self._parent:
+ raise EnvironmentError("Global diagrams context not set up")
+ self._parent.node(self)
+
+ def __enter__(self):
+ super().__enter__()
+
+ # Set attributes.
+ for k, v in self._default_graph_attrs.items():
+ self.dot.graph_attr[k] = v
+ for k, v in self._attrs.items():
+ self.dot.graph_attr[k] = v
+
+ icon = self._load_icon()
+ if icon:
+ lines = iter(html.escape(self.label).split("\n"))
+ self.dot.graph_attr["label"] = '<
' +\
+ f' | ' +\
+ f'{next(lines)} |
' +\
+ ''.join(f'{line} |
' for line in lines) +\
+ '
>'
else:
- self._diagram.node(self._id, self.label, **self._attrs)
+ self.dot.graph_attr["label"] = self.label
+
+ self.dot.graph_attr["rankdir"] = self._direction
+
+ # Set cluster depth for distinguishing the background color
+ self.depth = self._parent.depth + 1
+ coloridx = self.depth % len(self.__bgcolors)
+ self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]
+
+ return self
+
+ def __exit__(self, *args):
+ super().__exit__(*args)
+ self._id = "cluster_" + self.nodeid
+ self.dot.name = self.nodeid
def __repr__(self):
_name = self.__class__.__name__
@@ -400,7 +555,7 @@ class Node:
@property
def nodeid(self):
return self._id
-
+
# TODO: option for adding flow description to the connection edge
def connect(self, node: "Node", edge: "Edge"):
"""Connect to other node.
@@ -414,7 +569,7 @@ class Node:
if not isinstance(edge, Edge):
ValueError(f"{edge} is not a valid Edge")
# An edge must be added on the global diagrams, not a cluster.
- self._diagram.connect(self, node, edge)
+ getdiagram().connect(self, node, edge)
return node
@staticmethod
@@ -422,8 +577,10 @@ class Node:
return uuid.uuid4().hex
def _load_icon(self):
- basedir = Path(os.path.abspath(os.path.dirname(__file__)))
- return os.path.join(basedir.parent, self._icon_dir, self._icon)
+ if self._icon and self._icon_dir:
+ basedir = Path(os.path.abspath(os.path.dirname(__file__)))
+ return os.path.join(basedir.parent, self._icon_dir, self._icon)
+ return None
class Edge:
@@ -472,6 +629,7 @@ class Edge:
# Graphviz complaining about using label for edges, so replace it with xlabel.
# Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83
self._attrs["label"] = label
+ self._attrs["tooltip"] = label
if color:
self._attrs["color"] = color
if style:
@@ -544,4 +702,4 @@ class Edge:
return {**self._attrs, "dir": direction}
-Group = Cluster
+Group = Cluster = Node
diff --git a/diagrams/aws/__init__.py b/diagrams/aws/__init__.py
index 1550a0d..8c912ba 100644
--- a/diagrams/aws/__init__.py
+++ b/diagrams/aws/__init__.py
@@ -2,7 +2,7 @@
AWS provides a set of services for Amazon Web Service provider.
"""
-from diagrams import Node
+from diagrams import Node, Cluster
class _AWS(Node):
diff --git a/diagrams/aws/cluster.py b/diagrams/aws/cluster.py
new file mode 100644
index 0000000..6ecbabc
--- /dev/null
+++ b/diagrams/aws/cluster.py
@@ -0,0 +1,104 @@
+from diagrams import Cluster
+from diagrams.aws.compute import EC2, ApplicationAutoScaling
+from diagrams.aws.network import VPC, PrivateSubnet, PublicSubnet
+
+class Region(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dotted",
+ "labeljust": "l",
+ "pencolor": "#AEB6BE",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class AvailabilityZone(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dashed",
+ "labeljust": "l",
+ "pencolor": "#27a0ff",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class VirtualPrivateCloud(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#00D110",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = VPC
+
+class PrivateSubnet(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#329CFF",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = PrivateSubnet
+
+class PublicSubnet(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#00D110",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = PublicSubnet
+
+class SecurityGroup(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dashed",
+ "labeljust": "l",
+ "pencolor": "#FF361E",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class AutoScalling(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dashed",
+ "labeljust": "l",
+ "pencolor": "#FF7D1E",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = ApplicationAutoScaling
+
+class EC2Contents(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#FFB432",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = EC2
diff --git a/diagrams/azure/cluster.py b/diagrams/azure/cluster.py
new file mode 100644
index 0000000..73bfcda
--- /dev/null
+++ b/diagrams/azure/cluster.py
@@ -0,0 +1,143 @@
+from diagrams import Cluster
+from diagrams.azure.compute import VM, VMWindows, VMLinux #, VMScaleSet # Depends on PR-404
+from diagrams.azure.network import VirtualNetworks, Subnets, NetworkSecurityGroupsClassic
+
+class Subscription(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dotted",
+ "labeljust": "l",
+ "pencolor": "#AEB6BE",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class Region(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dotted",
+ "labeljust": "l",
+ "pencolor": "#AEB6BE",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class AvailabilityZone(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dashed",
+ "labeljust": "l",
+ "pencolor": "#27a0ff",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class VirtualNetwork(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#00D110",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = VirtualNetworks
+
+class SubnetWithNSG(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#329CFF",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = NetworkSecurityGroupsClassic
+
+class Subnet(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#00D110",
+ "fontname": "sans-serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = Subnets
+
+class SecurityGroup(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "dashed",
+ "labeljust": "l",
+ "pencolor": "#FF361E",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+
+class VMContents(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#FFB432",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = VM
+
+class VMLinuxContents(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#FFB432",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = VMLinux
+
+class VMWindowsContents(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "",
+ "labeljust": "l",
+ "pencolor": "#FFB432",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = VMWindows
+
+# Depends on PR-404
+# class VMSS(Cluster):
+# # fmt: off
+# _default_graph_attrs = {
+# "shape": "box",
+# "style": "dashed",
+# "labeljust": "l",
+# "pencolor": "#FF7D1E",
+# "fontname": "Sans-Serif",
+# "fontsize": "12",
+# }
+# # fmt: on
+# _icon = VMScaleSet
diff --git a/diagrams/onprem/cluster.py b/diagrams/onprem/cluster.py
new file mode 100644
index 0000000..4fd62f0
--- /dev/null
+++ b/diagrams/onprem/cluster.py
@@ -0,0 +1,15 @@
+from diagrams import Cluster
+from diagrams.onprem.compute import Server
+
+class ServerContents(Cluster):
+ # fmt: off
+ _default_graph_attrs = {
+ "shape": "box",
+ "style": "rounded,dotted",
+ "labeljust": "l",
+ "pencolor": "#A0A0A0",
+ "fontname": "Sans-Serif",
+ "fontsize": "12",
+ }
+ # fmt: on
+ _icon = Server
diff --git a/docs/guides/cluster.md b/docs/guides/cluster.md
index f4b55b2..056b6c9 100644
--- a/docs/guides/cluster.md
+++ b/docs/guides/cluster.md
@@ -66,6 +66,36 @@ with Diagram("Event Processing", show=False):
handlers >> dw
```
+## Clusters with icons in the label
+
+You can add a Node icon before the cluster label (and specify its size as well). You need to import the used Node class first.
+
+It's also possible to use the node in the `with` context adding `cluster=True` to
+make it behave like a cluster.
+
+```python
+from diagrams import Cluster, Diagram
+from diagrams.aws.compute import ECS
+from diagrams.aws.database import RDS, Aurora
+from diagrams.aws.network import Route53, VPC
+
+with Diagram("Simple Web Service with DB Cluster", show=False):
+ dns = Route53("dns")
+ web = ECS("service")
+
+ with Cluster(label='VPC',icon=VPC):
+ with Cluster("DB Cluster",icon=Aurora,icon_size=30):
+ db_master = RDS("master")
+ db_master - [RDS("slave1"),
+ RDS("slave2")]
+ with Aurora("DB Cluster", cluster=True):
+ db_master = RDS("master")
+ db_master - [RDS("slave1"),
+ RDS("slave2")]
+
+ dns >> web >> db_master
+```
+
![event processing diagram](/img/event_processing_diagram.png)
> There is no depth limit of nesting. Feel free to create nested clusters as deep as you want.
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/aws.png b/examples/aws.png
new file mode 100644
index 0000000..f762f76
Binary files /dev/null and b/examples/aws.png differ
diff --git a/examples/aws.py b/examples/aws.py
new file mode 100644
index 0000000..3599a5c
--- /dev/null
+++ b/examples/aws.py
@@ -0,0 +1,26 @@
+from diagrams import Diagram, Edge
+from diagrams.aws.cluster import *
+from diagrams.aws.compute import EC2
+from diagrams.onprem.container import Docker
+from diagrams.onprem.cluster import *
+from diagrams.aws.network import ELB
+
+with Diagram(name="", direction="TB", show=True):
+ with Cluster("AWS"):
+ with Region("eu-west-1"):
+ with AvailabilityZone("eu-west-1a"):
+ with VirtualPrivateCloud(""):
+ with PrivateSubnet("Private"):
+ with SecurityGroup("web sg"):
+ with AutoScalling(""):
+ with EC2Contents("A"):
+ d1 = Docker("Container")
+ with ServerContents("A1"):
+ d2 = Docker("Container")
+
+ with PublicSubnet("Public"):
+ with SecurityGroup("elb sg"):
+ lb = ELB()
+
+ lb >> Edge(forward=True, reverse=True) >> d1
+ lb >> Edge(forward=True, reverse=True) >> d2
diff --git a/examples/azure.png b/examples/azure.png
new file mode 100644
index 0000000..04af158
Binary files /dev/null and b/examples/azure.png differ
diff --git a/examples/azure.py b/examples/azure.py
new file mode 100644
index 0000000..e094887
--- /dev/null
+++ b/examples/azure.py
@@ -0,0 +1,24 @@
+from diagrams import Diagram, Edge
+from diagrams.azure.cluster import *
+from diagrams.azure.compute import VM
+from diagrams.onprem.container import Docker
+from diagrams.onprem.cluster import *
+from diagrams.azure.network import LoadBalancers
+
+with Diagram(name="", filename="azure", direction="TB", show=True):
+ with Cluster("Azure"):
+ with Region("East US2"):
+ with AvailabilityZone("Zone 2"):
+ with VirtualNetwork(""):
+ with SubnetWithNSG("Private"):
+ # with VMScaleSet(""): # Depends on PR-404
+ with VMContents("A"):
+ d1 = Docker("Container")
+ with ServerContents("A1"):
+ d2 = Docker("Container")
+
+ with Subnet("Public"):
+ lb = LoadBalancers()
+
+ lb >> Edge(forward=True, reverse=True) >> d1
+ lb >> Edge(forward=True, reverse=True) >> d2
diff --git a/tests/test_diagram.py b/tests/test_diagram.py
index 00bdacc..28369fe 100644
--- a/tests/test_diagram.py
+++ b/tests/test_diagram.py
@@ -154,20 +154,20 @@ class ClusterTest(unittest.TestCase):
def test_with_global_context(self):
with Diagram(name=os.path.join(self.name, "with_global_context"), show=False):
- self.assertIsNone(getcluster())
+ self.assertEqual(getcluster(), getdiagram())
with Cluster():
- self.assertIsNotNone(getcluster())
- self.assertIsNone(getcluster())
+ self.assertNotEqual(getcluster(), getdiagram())
+ self.assertEqual(getcluster(), getdiagram())
def test_with_nested_cluster(self):
with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False):
- self.assertIsNone(getcluster())
+ self.assertEqual(getcluster(), getdiagram())
with Cluster() as c1:
self.assertEqual(c1, getcluster())
with Cluster() as c2:
self.assertEqual(c2, getcluster())
self.assertEqual(c1, getcluster())
- self.assertIsNone(getcluster())
+ self.assertEqual(getcluster(), getdiagram())
def test_node_not_in_diagram(self):
# Node must be belong to a diagrams.