From a48ff4086451d48de192e13adc947a1fc9db282f Mon Sep 17 00:00:00 2001 From: Adrian Herrmann Date: Fri, 19 Apr 2024 23:23:48 +0200 Subject: QtAsyncio: Fix tasks with loop not cancelling If a task was cancelled, then a new future created from this task should be cancelled as well. Otherwise, in some scenarios like a loop inside the task and with bad timing, if the new future is not cancelled, the task would continue running in this loop despite having been cancelled. This bad timing can occur especially if the first future finishes very quickly. Fixes: PYSIDE-2644 Task-number: PYSIDE-769 Change-Id: Icfff6e4ad5da565f50e3d89fbf85d1fecbf93650 Reviewed-by: Cristian Maureira-Fredes (cherry picked from commit 9de4dee2f697dc88812dfad04ce4054cebf6be61) Reviewed-by: Qt Cherry-pick Bot --- sources/pyside6/PySide6/QtAsyncio/tasks.py | 11 +++++ .../QtAsyncio/qasyncio_test_cancel_taskgroup.py | 57 ++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py diff --git a/sources/pyside6/PySide6/QtAsyncio/tasks.py b/sources/pyside6/PySide6/QtAsyncio/tasks.py index bc3d41a73..6777b8bc3 100644 --- a/sources/pyside6/PySide6/QtAsyncio/tasks.py +++ b/sources/pyside6/PySide6/QtAsyncio/tasks.py @@ -29,6 +29,7 @@ class QAsyncioTask(futures.QAsyncioFuture): self._future_to_await: typing.Optional[asyncio.Future] = None self._cancel_message: typing.Optional[str] = None + self._cancelled = False asyncio._register_task(self) # type: ignore[arg-type] @@ -90,6 +91,15 @@ class QAsyncioTask(futures.QAsyncioFuture): result.add_done_callback( self._step, context=self._context) # type: ignore[arg-type] self._future_to_await = result + if self._cancelled: + # If the task was cancelled, then a new future should be + # cancelled as well. Otherwise, in some scenarios like + # a loop inside the task and with bad timing, if the new + # future is not cancelled, the task would continue running + # in this loop despite having been cancelled. This bad + # timing can occur especially if the first future finishes + # very quickly. + self._future_to_await.cancel(self._cancel_message) elif result is None: self._loop.call_soon(self._step, context=self._context) else: @@ -136,6 +146,7 @@ class QAsyncioTask(futures.QAsyncioFuture): self._handle.cancel() if self._future_to_await is not None: self._future_to_await.cancel(msg) + self._cancelled = True return True def uncancel(self) -> None: diff --git a/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py new file mode 100644 index 000000000..aa8ce4718 --- /dev/null +++ b/sources/pyside6/tests/QtAsyncio/qasyncio_test_cancel_taskgroup.py @@ -0,0 +1,57 @@ +# Copyright (C) 2024 The Qt Company Ltd. +# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0 + +'''Test cases for QtAsyncio''' + +import asyncio +import unittest + +import PySide6.QtAsyncio as QtAsyncio + + +class QAsyncioTestCaseCancelTaskGroup(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + # We only reach the end of the loop if the task is not cancelled. + self.loop_end_reached = False + + async def raise_error(self): + raise RuntimeError + + async def loop_short(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.sleep(1e-3) + self._loop_end_reached = True + + async def loop_shorter(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.sleep(1e-4) + self._loop_end_reached = True + + async def loop_the_shortest(self): + self._loop_end_reached = False + for _ in range(1000): + await asyncio.to_thread(lambda: None) + self._loop_end_reached = True + + async def main(self, coro): + async with asyncio.TaskGroup() as tg: + tg.create_task(coro()) + tg.create_task(self.raise_error()) + + def test_cancel_taskgroup(self): + coros = [self.loop_short, self.loop_shorter, self.loop_the_shortest] + + for coro in coros: + try: + QtAsyncio.run(self.main(coro), keep_running=False) + except ExceptionGroup as e: + self.assertEqual(len(e.exceptions), 1) + self.assertIsInstance(e.exceptions[0], RuntimeError) + self.assertFalse(self._loop_end_reached) + + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3