Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle async cancelled error explicitly #811

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
from .._synchronization import (
AsyncEvent,
AsyncLock,
AsyncShieldCancellation,
get_cancelled_exc_class,
)
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface

Expand Down Expand Up @@ -113,6 +118,7 @@ def __init__(
AutoBackend() if network_backend is None else network_backend
)
self._socket_options = socket_options
self._cancelled_exc = get_cancelled_exc_class()

def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
return AsyncHTTPConnection(
Expand Down Expand Up @@ -231,7 +237,7 @@ async def handle_async_request(self, request: Request) -> Response:
timeout = timeouts.get("pool", None)
try:
connection = await status.wait_for_connection(timeout=timeout)
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
# If we timeout here, or if the task is cancelled, then make
# sure to remove the request from the queue before bubbling
# up the exception.
Expand All @@ -256,7 +262,7 @@ async def handle_async_request(self, request: Request) -> Response:
# status so that the request becomes queued again.
status.unset_connection()
await self._attempt_to_acquire_connection(status)
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
with AsyncShieldCancellation():
await self.response_closed(status)
raise exc
Expand Down
12 changes: 9 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncShieldCancellation
from .._synchronization import (
AsyncLock,
AsyncShieldCancellation,
get_cancelled_exc_class,
)
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -67,6 +71,7 @@ def __init__(
our_role=h11.CLIENT,
max_incomplete_event_size=self.MAX_INCOMPLETE_EVENT_SIZE,
)
self._cancelled_exc = get_cancelled_exc_class()

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand Down Expand Up @@ -126,7 +131,7 @@ async def handle_async_request(self, request: Request) -> Response:
"network_stream": self._network_stream,
},
)
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
Expand Down Expand Up @@ -321,14 +326,15 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False
self._cancelled_exc = get_cancelled_exc_class()

async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
Expand Down
15 changes: 11 additions & 4 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._synchronization import (
AsyncLock,
AsyncSemaphore,
AsyncShieldCancellation,
get_cancelled_exc_class,
)
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -81,6 +86,7 @@ def __init__(

self._read_exception: typing.Optional[Exception] = None
self._write_exception: typing.Optional[Exception] = None
self._cancelled_exc = get_cancelled_exc_class()

async def handle_async_request(self, request: Request) -> Response:
if not self.can_handle_request(request.url.origin):
Expand All @@ -107,7 +113,7 @@ async def handle_async_request(self, request: Request) -> Response:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
with AsyncShieldCancellation():
await self.aclose()
raise exc
Expand Down Expand Up @@ -160,7 +166,7 @@ async def handle_async_request(self, request: Request) -> Response:
"stream_id": stream_id,
},
)
except BaseException as exc: # noqa: PIE786
except (Exception, self._cancelled_exc) as exc: # noqa: PIE786
with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
Expand Down Expand Up @@ -564,6 +570,7 @@ def __init__(
self._request = request
self._stream_id = stream_id
self._closed = False
self._cancelled_exc = get_cancelled_exc_class()

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
Expand All @@ -573,7 +580,7 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
request=self._request, stream_id=self._stream_id
):
yield chunk
except BaseException as exc:
except (Exception, self._cancelled_exc) as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
Expand Down
20 changes: 20 additions & 0 deletions httpcore/_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,26 @@ def __exit__(
self._anyio_shield.__exit__(exc_type, exc_value, traceback)


def get_cancelled_exc_class() -> Type[BaseException]:
"""
Detect if we're running under 'asyncio' or 'trio' and return
cannelled exception class of it.
"""
backend = sniffio.current_async_library()
if backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio, requires the 'trio' package to be installed."
)
return trio.Cancelled

if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)
return anyio.get_cancelled_exc_class()


# Our thread-based synchronization primitives...


Expand Down