diff --git a/svgelements/svgelements.py b/svgelements/svgelements.py index eeaa2458..ca22faf9 100644 --- a/svgelements/svgelements.py +++ b/svgelements/svgelements.py @@ -7690,6 +7690,47 @@ def select(self, conditional=None): for s in subitem.select(conditional): yield s + @staticmethod + def union_bbox(elements, transformed=True, with_stroke=False): + """ + Returns the union of the bounding boxes for the elements within the iterator. + + :param transformed: Should the children of this object be properly transformed. + :param with_stroke: should the stroke-width be included in the bounds of the elements + :return: union of all bounding boxes of elements within the iterable. + """ + boxes = [] + for e in elements: + if not hasattr(e, "bbox") or isinstance(e, (Group, Use)): + continue + box = e.bbox(transformed=transformed, with_stroke=with_stroke) + if box is None: + continue + boxes.append(box) + if len(boxes) == 0: + return None + (xmins, ymins, xmaxs, ymaxs) = zip(*boxes) + return (min(xmins), min(ymins), max(xmaxs), max(ymaxs)) + + def bbox(self, transformed=True, with_stroke=False): + """ + Returns the bounding box of the given object. + + In the case of groups this is the union of all the bounding boxes of all bound children. + + Setting transformed to false, may yield unexpected results if subitems are transformed in non-uniform + ways. + + :param transformed: bounding box of the properly transformed children. + :param with_stroke: should the stroke-width be included in the bounds. + :return: bounding box of the given element + """ + return Use.union_bbox( + self.select(), + transformed=transformed, + with_stroke=with_stroke, + ) + class ClipPath(SVGElement, list): """ diff --git a/test/test_use.py b/test/test_use.py index 5ed7defb..02fbdf9e 100644 --- a/test/test_use.py +++ b/test/test_use.py @@ -6,6 +6,22 @@ class TestElementUse(unittest.TestCase): + def test_use_bbox_method(self): + q = io.StringIO(u''' + + + + + + + ''') + svg = SVG.parse(q) + use = list(svg.select(lambda e: isinstance(e, Use))) + self.assertEqual(2, len(use)) + self.assertEqual((0.0, 20.0, (0.0 + 50.0), (20.0 + 50.0)), use[0].bbox()) + self.assertEqual((20.0 + 0.0, 20.0 + 20.0, (20.0 + 50.0), (20.0 + 20.0 + 50.0)), use[1].bbox()) + def test_issue_156(self): q1 = io.StringIO(u'''