Skip to content

Commit

Permalink
Accept AsyncIterables being passed to Response
Browse files Browse the repository at this point in the history
  • Loading branch information
mjsir911 committed May 20, 2024
1 parent 2fc6d4f commit 3df3018
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/quart/wrappers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,23 @@ async def __anext__(self) -> bytes:


class IterableBody(ResponseBody):
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
def __init__(self, iterable: AsyncIterable[bytes] | Iterable) -> None:
self.iter: AsyncGenerator[bytes, None]
if isasyncgen(iterable):
self.iter = iterable
elif isgenerator(iterable):
self.iter = run_sync_iterable(iterable)
elif isinstance(iterable, AsyncIterable):

async def _aiter() -> AsyncGenerator[bytes, None]:
async for data in iterable:
yield data

self.iter = _aiter()
else:

async def _aiter() -> AsyncGenerator[bytes, None]:
for data in iterable: # type: ignore
for data in iterable:
yield data

self.iter = _aiter()
Expand Down Expand Up @@ -262,7 +269,7 @@ class Response(SansIOResponse):

def __init__(
self,
response: ResponseBody | AnyStr | Iterable | None = None,
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
status: int | None = None,
headers: dict | Headers | None = None,
mimetype: str | None = None,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
g,
Quart,
render_template_string,
Response,
ResponseReturnValue,
session,
stream_template_string,
Expand Down Expand Up @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
test_client = app.test_client()
response = await test_client.get("/")
assert (await response.data) == b"42"

@app.get("/2")
async def index2() -> ResponseReturnValue:
return Response(await stream_template_string("{{ config }}", config=43))

test_client = app.test_client()
response = await test_client.get("/2")
assert (await response.data) == b"43"

0 comments on commit 3df3018

Please sign in to comment.