diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 7868875..d20af68 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -39,6 +39,7 @@ def setcluster(cluster): class Diagram: __directions = ("TB", "BT", "LR", "RL") + __curvestyles = ("ortho", "curved") __outformats = ("png", "jpg", "svg", "pdf") # fmt: off @@ -78,6 +79,7 @@ class Diagram: name: str = "", filename: str = "", direction: str = "LR", + curvestyle: str = "ortho", outformat: str = "png", show: bool = True, graph_attr: dict = {}, @@ -91,6 +93,7 @@ class Diagram: :param filename: The output filename, without the extension (.png). If not given, it will be generated from the name. :param direction: Data flow direction. Default is 'left to right'. + :param curvestyle: Curve bending style. One of "ortho" or "curved". :param outformat: Output file format. Default is 'png'. :param show: Open generated image after save if true, just only save otherwise. :param graph_attr: Provide graph_attr dot config attributes. @@ -117,6 +120,10 @@ class Diagram: raise ValueError(f'"{direction}" is not a valid direction') self.dot.graph_attr["rankdir"] = direction + if not self._validate_curvestyle(curvestyle): + raise ValueError(f'"{curvestyle}" is not a valid curvestyle') + self.dot.graph_attr["splines"] = curvestyle + if not self._validate_outformat(outformat): raise ValueError(f'"{outformat}" is not a valid output format') self.outformat = outformat @@ -151,6 +158,13 @@ class Diagram: return True return False + def _validate_curvestyle(self, curvestyle: str) -> bool: + curvestyle = curvestyle.lower() + for v in self.__curvestyles: + if v == curvestyle: + return True + return False + def _validate_outformat(self, outformat: str) -> bool: outformat = outformat.lower() for v in self.__outformats: diff --git a/tests/test_diagram.py b/tests/test_diagram.py index 3cc2128..fd8abc2 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -33,6 +33,16 @@ class DiagramTest(unittest.TestCase): with self.assertRaises(ValueError): Diagram(direction=dir) + def test_validate_curvestyle(self): + # Normal directions. + for cvs in ("ortho", "curved"): + Diagram(curvestyle=cvs) + + # Invalid directions. + for cvs in ("tangent", "unknown"): + with self.assertRaises(ValueError): + Diagram(curvestyle=cvs) + def test_validate_outformat(self): # Normal output formats. for fmt in ("png", "jpg", "svg", "pdf"):