Skip to content

Commit

Permalink
Accept AsyncIterables being passed to Response
Browse files Browse the repository at this point in the history
  • Loading branch information
mjsir911 committed May 21, 2024
1 parent 2fc6d4f commit b6fc23f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
14 changes: 9 additions & 5 deletions src/quart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterable,
Iterator,
TYPE_CHECKING,
TypeVar,
)

from werkzeug.datastructures import Headers
Expand Down Expand Up @@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
return _wrapper


def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]:
async def _gen_wrapper() -> AsyncGenerator[Any, None]:
T = TypeVar("T")


def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]:
async def _gen_wrapper() -> AsyncIterator[T]:
# Wrap the generator such that each iteration runs
# in the executor. Then rationalise the raised
# errors so that it ends.
def _inner() -> Any:
def _inner() -> T:
# https://bugs.python.org/issue26221
# StopIteration errors are swallowed by the
# run_in_exector method
Expand Down
26 changes: 10 additions & 16 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from builtins import aiter
from hashlib import md5
from inspect import isasyncgen, isgenerator
from io import BytesIO
from os import PathLike
from types import TracebackType
Expand Down Expand Up @@ -102,27 +102,21 @@ async def __anext__(self) -> bytes:


class IterableBody(ResponseBody):
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
self.iter: AsyncGenerator[bytes, None]
if isasyncgen(iterable):
self.iter = iterable
elif isgenerator(iterable):
self.iter = run_sync_iterable(iterable)
def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None:
self.iter: AsyncIterator[Any]
if isinstance(iterable, Iterable):
self.iter = run_sync_iterable(iter(iterable))
else:

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable: # type: ignore
yield data

self.iter = _aiter()
self.iter = aiter(iterable)

async def __aenter__(self) -> IterableBody:
return self

async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
await self.iter.aclose()
if hasattr(self.iter, "aclose"): # Is a generator?
await self.iter.aclose()

def __aiter__(self) -> AsyncIterator:
def __aiter__(self) -> AsyncIterator[Any]:
return self.iter


Expand Down Expand Up @@ -262,7 +256,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
g,
Quart,
render_template_string,
Response,
ResponseReturnValue,
session,
stream_template_string,
Expand Down Expand Up @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
test_client = app.test_client()
response = await test_client.get("/")
assert (await response.data) == b"42"

@app.get("/2")
async def index2() -> ResponseReturnValue:
return Response(await stream_template_string("{{ config }}", config=43))

test_client = app.test_client()
response = await test_client.get("/2")
assert (await response.data) == b"43"

0 comments on commit b6fc23f

Please sign in to comment.