aboutsummaryrefslogtreecommitdiffstats
path: root/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py
diff options
context:
space:
mode:
Diffstat (limited to 'sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py')
-rw-r--r--sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py
new file mode 100644
index 000000000..aa8ce4718
--- /dev/null
+++ b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2024 The Qt Company Ltd.
+# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0
+
+'''Test cases for QtAsyncio'''
+
+import asyncio
+import unittest
+
+import PySide6.QtAsyncio as QtAsyncio
+
+
+class QAsyncioTestCaseCancelTaskGroup(unittest.TestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ # We only reach the end of the loop if the task is not cancelled.
+ self.loop_end_reached = False
+
+ async def raise_error(self):
+ raise RuntimeError
+
+ async def loop_short(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.sleep(1e-3)
+ self._loop_end_reached = True
+
+ async def loop_shorter(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.sleep(1e-4)
+ self._loop_end_reached = True
+
+ async def loop_the_shortest(self):
+ self._loop_end_reached = False
+ for _ in range(1000):
+ await asyncio.to_thread(lambda: None)
+ self._loop_end_reached = True
+
+ async def main(self, coro):
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(coro())
+ tg.create_task(self.raise_error())
+
+ def test_cancel_taskgroup(self):
+ coros = [self.loop_short, self.loop_shorter, self.loop_the_shortest]
+
+ for coro in coros:
+ try:
+ QtAsyncio.run(self.main(coro), keep_running=False)
+ except ExceptionGroup as e:
+ self.assertEqual(len(e.exceptions), 1)
+ self.assertIsInstance(e.exceptions[0], RuntimeError)
+ self.assertFalse(self._loop_end_reached)
+
+
+if __name__ == '__main__':
+ unittest.main()