Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support bytes as query value #3212

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
8 changes: 6 additions & 2 deletions httpx/_urlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ipaddress
import re
import typing
from urllib.parse import quote_from_bytes

import idna

Expand Down Expand Up @@ -419,10 +420,13 @@ def PERCENT(string: str) -> str:
return "".join([f"%{byte:02X}" for byte in string.encode("utf-8")])


def percent_encoded(string: str, safe: str = "/") -> str:
def percent_encoded(string: str | bytes, safe: str = "/") -> str:
"""
Use percent-encoding to quote a string.
"""
if isinstance(string, bytes):
return quote_from_bytes(string)

NON_ESCAPED_CHARS = UNRESERVED_CHARACTERS + safe

# Fast path for strings that don't need escaping.
Expand Down Expand Up @@ -467,7 +471,7 @@ def quote(string: str, safe: str = "/") -> str:
return "".join(parts)


def urlencode(items: list[tuple[str, str]]) -> str:
def urlencode(items: list[tuple[str, str | bytes]]) -> str:
"""
We can use a much simpler version of the stdlib urlencode here because
we don't need to handle a bunch of different typing cases, such as bytes vs str.
Expand Down
52 changes: 36 additions & 16 deletions httpx/_urls.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import copy
import typing
from urllib.parse import parse_qs, unquote

import idna

from ._types import QueryParamTypes, RawURL, URLTypes
from ._urlparse import urlencode, urlparse
from ._utils import primitive_value_to_str
from ._utils import encode_query_value

__all__ = ["URL", "QueryParams"]

Expand Down Expand Up @@ -417,20 +418,30 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({url!r})"


class QueryParams(typing.Mapping[str, str]):
class QueryParams(typing.Mapping[str, typing.Union[str, bytes]]):
"""
URL query parameters, as a multi-dict.
"""

def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
_dict: dict[str, list[str | bytes]]

__slots__ = ("_dict",)

@typing.overload
def __init__(self, qs: QueryParamTypes | None, /) -> None: ...

@typing.overload
def __init__(self, /, **kwargs: typing.Any) -> None: ...

def __init__(self, /, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
assert len(args) < 2, "Too many arguments."
assert not (args and kwargs), "Cannot mix named and unnamed arguments."

value = args[0] if args else kwargs

if value is None or isinstance(value, (str, bytes)):
value = value.decode("ascii") if isinstance(value, bytes) else value
self._dict = parse_qs(value, keep_blank_values=True)
self._dict = parse_qs(value, keep_blank_values=True) # type: ignore[assignment]
elif isinstance(value, QueryParams):
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
Expand All @@ -456,7 +467,7 @@ def __init__(self, *args: QueryParamTypes | None, **kwargs: typing.Any) -> None:
# We coerce values `True` and `False` to JSON-like "true" and "false"
# representations, and coerce `None` values to the empty string.
self._dict = {
str(k): [primitive_value_to_str(item) for item in v]
str(k): [encode_query_value(item) for item in v]
for k, v in dict_value.items()
}

Expand All @@ -471,7 +482,7 @@ def keys(self) -> typing.KeysView[str]:
"""
return self._dict.keys()

def values(self) -> typing.ValuesView[str]:
def values(self) -> typing.ValuesView[str | bytes]:
"""
Return all the values in the query params. If a key occurs more than once
only the first item for that key is returned.
Expand All @@ -483,7 +494,7 @@ def values(self) -> typing.ValuesView[str]:
"""
return {k: v[0] for k, v in self._dict.items()}.values()

def items(self) -> typing.ItemsView[str, str]:
def items(self) -> typing.ItemsView[str, str | bytes]:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.
Expand All @@ -495,7 +506,7 @@ def items(self) -> typing.ItemsView[str, str]:
"""
return {k: v[0] for k, v in self._dict.items()}.items()

def multi_items(self) -> list[tuple[str, str]]:
def multi_items(self) -> list[tuple[str, str | bytes]]:
"""
Return all items in the query params. Allow duplicate keys to occur.

Expand All @@ -504,7 +515,7 @@ def multi_items(self) -> list[tuple[str, str]]:
q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: list[tuple[str, str]] = []
multi_items: list[tuple[str, str | bytes]] = []
for k, v in self._dict.items():
multi_items.extend([(k, i) for i in v])
return multi_items
Expand All @@ -523,7 +534,7 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
return self._dict[str(key)][0]
return default

def get_list(self, key: str) -> list[str]:
def get_list(self, key: str) -> list[str | bytes]:
"""
Get all values from the query param for a given key.

Expand All @@ -545,8 +556,8 @@ def set(self, key: str, value: typing.Any = None) -> QueryParams:
assert q == httpx.QueryParams("a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = [primitive_value_to_str(value)]
q._dict = copy.deepcopy(self._dict)
q._dict[str(key)] = [encode_query_value(value)]
return q

def add(self, key: str, value: typing.Any = None) -> QueryParams:
Expand All @@ -560,8 +571,8 @@ def add(self, key: str, value: typing.Any = None) -> QueryParams:
assert q == httpx.QueryParams("a=123&a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
q._dict = copy.deepcopy(self._dict)
q._dict[str(key)] = q.get_list(key) + [encode_query_value(value)]
return q

def remove(self, key: str) -> QueryParams:
Expand All @@ -575,7 +586,7 @@ def remove(self, key: str) -> QueryParams:
assert q == httpx.QueryParams("")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict = copy.deepcopy(self._dict)
q._dict.pop(str(key), None)
return q

Expand All @@ -597,7 +608,7 @@ def merge(self, params: QueryParamTypes | None = None) -> QueryParams:
q._dict = {**self._dict, **q._dict}
return q

def __getitem__(self, key: typing.Any) -> str:
def __getitem__(self, key: typing.Any) -> str | bytes:
return self._dict[key][0]

def __contains__(self, key: typing.Any) -> bool:
Expand Down Expand Up @@ -646,3 +657,12 @@ def __setitem__(self, key: str, value: str) -> None:
"QueryParams are immutable since 0.18.0. "
"Use `q = q.set(key, value)` to create an updated copy."
)


if typing.TYPE_CHECKING:  # pragma: no cover
    # Static-typing checks only: `typing.TYPE_CHECKING` is False at runtime,
    # so none of these calls are ever executed.
    #
    # The first two calls pass two positional arguments, which matches neither
    # `__init__` overload; the `type: ignore[call-overload]` marks the error
    # the type checker is expected to report there.
    QueryParams("q=a", {"q": "a"})  # type: ignore[call-overload]
    QueryParams({"a": 1}, {"q": "a"})  # type: ignore[call-overload]
    # The remaining calls carry no `type: ignore`, so they are expected to
    # type-check cleanly: one positional value, or keyword arguments only.
    QueryParams("q=a")
    QueryParams({"q": "a"})
    QueryParams(q="a")
15 changes: 14 additions & 1 deletion httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL


_HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
_HTML5_FORM_ENCODING_REPLACEMENTS.update(
{chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B}
Expand Down Expand Up @@ -68,6 +67,20 @@ def primitive_value_to_str(value: PrimitiveData) -> str:
return str(value)


def encode_query_value(value: typing.Any) -> str | bytes:
    """
    Coerce a query parameter value into the `str` or `bytes` form that is stored.

    * `str` and `bytes` values are returned unchanged.
    * `True` / `False` become the JSON-like `"true"` / `"false"`.
    * `None` becomes the empty string.
    * Other `int` and `float` values are converted with `str()`.

    Raises `TypeError` for any other type.
    """
    if value is None:
        return ""
    # Identity checks (rather than equality) so that `1` and `0` are not
    # mistaken for booleans.
    if value is True or value is False:
        return "true" if value else "false"
    if isinstance(value, (str, bytes)):
        return value
    if not isinstance(value, (int, float)):
        raise TypeError(f"can't use {type(value)!r} as query value")
    return str(value)


def is_known_encoding(encoding: str) -> bool:
"""
Return `True` if `encoding` is a known codec.
Expand Down
13 changes: 13 additions & 0 deletions tests/models/test_queryparams.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import re
import types

import pytest

import httpx
Expand Down Expand Up @@ -134,3 +137,13 @@ def test_queryparams_are_hashable():
)

assert len(set(params)) == 2


def test_queryparams_bytes():
    """A `bytes` query value is serialised as the expected percent-encoded string."""
    raw_value = bytes.fromhex("E1EE0E2734986F5419BB6C")
    params = httpx.QueryParams({"q": raw_value})
    assert str(params) == "q=%E1%EE%0E%274%98oT%19%BBl"


def test_queryparams_error():
    """An unsupported value type is rejected with a descriptive `TypeError`."""
    unsupported = types.SimpleNamespace()
    expected_message = re.compile(r"can't use .* as query value")
    with pytest.raises(TypeError, match=expected_message):
        httpx.QueryParams({"q": unsupported})  # type: ignore
8 changes: 4 additions & 4 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def test_urlencoded_content():

@pytest.mark.anyio
async def test_urlencoded_boolean():
request = httpx.Request(method, url, data={"example": True})
request = httpx.Request(method, url, data={"example": True, "e2": False})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the `"e2": False` entry so that the `False` case is also exercised, for test coverage.

assert isinstance(request.stream, typing.Iterable)
assert isinstance(request.stream, typing.AsyncIterable)

Expand All @@ -209,11 +209,11 @@ async def test_urlencoded_boolean():

assert request.headers == {
"Host": "www.example.com",
"Content-Length": "12",
"Content-Length": "21",
"Content-Type": "application/x-www-form-urlencoded",
}
assert sync_content == b"example=true"
assert async_content == b"example=true"
assert sync_content == b"example=true&e2=false"
assert async_content == b"example=true&e2=false"


@pytest.mark.anyio
Expand Down