diff --git a/src/quart/utils.py b/src/quart/utils.py index 7e48ace..a92fc03 100644 --- a/src/quart/utils.py +++ b/src/quart/utils.py @@ -10,13 +10,14 @@ from pathlib import Path from typing import ( Any, - AsyncGenerator, + AsyncIterator, Awaitable, Callable, Coroutine, - Generator, Iterable, + Iterator, TYPE_CHECKING, + TypeVar, ) from werkzeug.datastructures import Headers @@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any: return _wrapper -def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]: - async def _gen_wrapper() -> AsyncGenerator[Any, None]: +T = TypeVar("T") + + +def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]: + async def _gen_wrapper() -> AsyncIterator[T]: # Wrap the generator such that each iteration runs # in the executor. Then rationalise the raised # errors so that it ends. - def _inner() -> Any: + def _inner() -> T: # https://bugs.python.org/issue26221 # StopIteration errors are swallowed by the # run_in_exector method diff --git a/src/quart/wrappers/response.py b/src/quart/wrappers/response.py index ad2e6c5..01bf4c0 100644 --- a/src/quart/wrappers/response.py +++ b/src/quart/wrappers/response.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from hashlib import md5 -from inspect import isasyncgen, isgenerator from io import BytesIO from os import PathLike from types import TracebackType @@ -101,27 +100,21 @@ async def __anext__(self) -> bytes: class IterableBody(ResponseBody): - def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None: - self.iter: AsyncGenerator[bytes, None] - if isasyncgen(iterable): - self.iter = iterable - elif isgenerator(iterable): - self.iter = run_sync_iterable(iterable) + def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None: + self.iter: AsyncIterator[Any] + if isinstance(iterable, Iterable): + self.iter = run_sync_iterable(iter(iterable)) else: - - async def _aiter() -> AsyncGenerator[bytes, None]: - for data in iterable: # type: ignore - yield data - - self.iter = _aiter() + self.iter = iterable.__aiter__() # Can't use aiter() until 3.10 async def __aenter__(self) -> IterableBody: return self async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None: - await self.iter.aclose() + if hasattr(self.iter, "aclose"): + await self.iter.aclose() - def __aiter__(self) -> AsyncIterator: + def __aiter__(self) -> AsyncIterator[Any]: return self.iter @@ -261,7 +254,7 @@ class Response(SansIOResponse): def __init__( self, - response: ResponseBody | str | bytes | Iterable | None = None, + response: ResponseBody | str | bytes | Iterable | AsyncIterable | None = None, status: int | None = None, headers: dict | Headers | None = None, mimetype: str | None = None, diff --git a/tests/test_templating.py b/tests/test_templating.py index f745020..5df5403 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -9,6 +9,7 @@ g, Quart, render_template_string, + Response, ResponseReturnValue, session, stream_template_string, @@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue: test_client = app.test_client() response = await test_client.get("/") assert (await response.data) == b"42" + + @app.get("/2") + async def index2() -> ResponseReturnValue: + return Response(await stream_template_string("{{ config }}", config=43)) + + test_client = app.test_client() + response = await test_client.get("/2") + assert (await response.data) == b"43"