From 854fff5947b783949d9590306a7e106a6aff490c Mon Sep 17 00:00:00 2001 From: Bruno Meneguello <1322552+bkmeneguello@users.noreply.github.com> Date: Mon, 4 Jan 2021 10:20:15 -0300 Subject: [PATCH] Fix unit tests --- diagrams/__init__.py | 9 ++++----- tests/test_diagram.py | 10 +++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 6261441..6610a78 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -17,10 +17,7 @@ __cluster = contextvars.ContextVar("cluster") def getdiagram(): - try: - return __diagram.get() - except LookupError: - raise EnvironmentError("Global diagrams context not set up") + return __diagram.get() def setdiagram(diagram): @@ -53,7 +50,7 @@ class _Cluster: try: self._parent = getcluster() or getdiagram() - except EnvironmentError: + except LookupError: self._parent = None @@ -348,6 +345,8 @@ class Node(_Cluster): self._attrs.update(attrs) # If a node is in the cluster context, add it to cluster. + if not self._parent: + raise EnvironmentError("Global diagrams context not set up") self._parent.node(self) def __enter__(self): diff --git a/tests/test_diagram.py b/tests/test_diagram.py index ad8558c..0242e9a 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -135,20 +135,20 @@ class ClusterTest(unittest.TestCase): def test_with_global_context(self): with Diagram(name=os.path.join(self.name, "with_global_context"), show=False): - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) with Cluster(): - self.assertIsNotNone(getcluster()) - self.assertIsNone(getcluster()) + self.assertNotEqual(getcluster(), getdiagram()) + self.assertEqual(getcluster(), getdiagram()) def test_with_nested_cluster(self): with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False): - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) with Cluster() as c1: self.assertEqual(c1, getcluster()) with Cluster() as c2: self.assertEqual(c2, getcluster()) self.assertEqual(c1, getcluster()) - self.assertIsNone(getcluster()) + self.assertEqual(getcluster(), getdiagram()) def test_node_not_in_diagram(self): # Node must be belong to a diagrams.