Skip to content

Commit

Permalink
fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
CarliJoy committed May 17, 2021
1 parent 3ca356a commit 7b7f4ef
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 22 deletions.
14 changes: 7 additions & 7 deletions src/jinja2/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from .utils import _PassArg
from .utils import concat
from .utils import consume
from .utils import EscapeFunc
from .utils import get_wrapped_escape_class
from .utils import import_string
from .utils import internalcode
Expand Down Expand Up @@ -346,7 +347,7 @@ class Environment:
context_class: t.Type[Context] = Context

template_class: t.Type["Template"]
default_markup_class: t.Type["Markup"]
default_markup_class: t.Type[Markup]

def __init__(
self,
Expand All @@ -372,7 +373,7 @@ def __init__(
auto_reload: bool = True,
bytecode_cache: t.Optional["BytecodeCache"] = None,
enable_async: bool = False,
default_escape=html_escape,
default_escape: t.Union[EscapeFunc, t.Type[Markup]] = html_escape,
allow_mixed_escape_extends: bool = False,
):
# !!Important notice!!
Expand Down Expand Up @@ -406,6 +407,7 @@ def __init__(
self.finalize = finalize
self.autoescape = autoescape
if isclass(default_escape):
default_escape = t.cast(t.Type[Markup], default_escape)
self.default_markup_class = default_escape
elif default_escape != html_escape:
self.default_markup_class = get_wrapped_escape_class(default_escape)
Expand Down Expand Up @@ -434,9 +436,7 @@ def __init__(
self.is_async = enable_async
_environment_config_check(self)

def get_markup_class(
self, template_name: t.Optional[str] = None
) -> t.Type["Markup"]:
def get_markup_class(self, template_name: t.Optional[str] = None) -> t.Type[Markup]:
"""
Get the correct :class:`Markup` for the given template name.
Expand Down Expand Up @@ -1193,14 +1193,14 @@ def select_template(

for name in names:
if isinstance(name, Template):
self._check_multi_template_autoescape(names, parent_template, caller)
self._check_multi_template_autoescape(name, parent_template, caller)
return name
if parent is not None:
name = self.join_path(name, parent)
try:
template = self._load_template(name, globals)
# Only check autoescape if template can be loaded
self._check_multi_template_autoescape(names, parent_template, caller)
self._check_multi_template_autoescape(name, parent_template, caller)
return template
except (TemplateNotFound, UndefinedError):
pass
Expand Down
4 changes: 4 additions & 0 deletions src/jinja2/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def _make_new_gettext(func: t.Callable[[str], str]) -> t.Callable[..., str]:
def gettext(__context: Context, __string: str, **variables: t.Any) -> str:
rv = __context.call(func, __string)
if __context.eval_ctx.autoescape:
rv = t.cast(str, rv)
rv = __context.eval_ctx.mark_safe(rv)
# Always treat as a format string, even if there are no
# variables. This makes translation strings more consistent
Expand All @@ -192,6 +193,7 @@ def ngettext(
variables.setdefault("num", __num)
rv = __context.call(func, __singular, __plural, __num)
if __context.eval_ctx.autoescape:
rv = t.cast(str, rv)
rv = __context.eval_ctx.mark_safe(rv)
# Always treat as a format string, see gettext comment above.
return rv % variables # type: ignore
Expand All @@ -208,6 +210,7 @@ def pgettext(
rv = __context.call(func, __string_ctx, __string)

if __context.eval_ctx.autoescape:
rv = t.cast(str, rv)
rv = __context.eval_ctx.mark_safe(rv)

# Always treat as a format string, see gettext comment above.
Expand All @@ -233,6 +236,7 @@ def npgettext(
rv = __context.call(func, __string_ctx, __singular, __plural, __num)

if __context.eval_ctx.autoescape:
rv = t.cast(str, rv)
rv = __context.eval_ctx.mark_safe(rv)

# Always treat as a format string, see gettext comment above.
Expand Down
2 changes: 1 addition & 1 deletion src/jinja2/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def do_replace(
):
s = do_escape(eval_ctx, s)
else:
s = markupsafe.soft_str(s) # type: ignore
s = markupsafe.soft_str(s)

# Special case, if user uses Markup class directly to mark
# something as safe but uses custom escape function
Expand Down
9 changes: 5 additions & 4 deletions src/jinja2/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .nodes import EvalContext
from .utils import _PassArg
from .utils import concat
from .utils import EscapeFunc
from .utils import internalcode
from .utils import missing
from .utils import Namespace # noqa: F401
Expand Down Expand Up @@ -71,7 +72,7 @@ def identity(x: V) -> V:
return x


def markup_join(seq: t.Iterable[t.Any], escape_func=html_escape) -> str:
def markup_join(seq: t.Iterable[t.Any], escape_func: EscapeFunc = html_escape) -> str:
"""
Concatenation that escapes if necessary and converts to string.
Expand All @@ -88,7 +89,7 @@ def markup_join(seq: t.Iterable[t.Any], escape_func=html_escape) -> str:
return concat(buf)


def str_join(seq: t.Iterable[t.Any], escape_func=html_escape):
def str_join(seq: t.Iterable[t.Any], escape_func: EscapeFunc = html_escape) -> str:
"""
Simple args to string conversion and concatenation.
Expand Down Expand Up @@ -727,7 +728,7 @@ def __init__(
default_autoescape: t.Optional[bool] = None,
):
self._environment = environment
self._mark_safe = environment.default_markup_class
self._mark_safe: EscapeFunc = environment.get_markup_class()
self._func = func
self._argument_count = len(arguments)
self.name = name
Expand Down Expand Up @@ -770,7 +771,7 @@ def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
# If the eval context is available we use it to determine
# the correct mark safe method
# otherwise mark safe is already set in the __init__
# function from enviromental context
# function from environmental context
self._mark_safe = args[0].mark_safe
args = args[1:]
else:
Expand Down
23 changes: 13 additions & 10 deletions src/jinja2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import typing_extensions as te

F = t.TypeVar("F", bound=t.Callable[..., t.Any])
# Typing definition of the Escape function
EscapeFunc = t.Callable[[t.Any], markupsafe.Markup]

# special singleton representing missing values for the runtime
missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
Expand Down Expand Up @@ -660,10 +662,10 @@ def __reversed__(self) -> t.Iterator[t.Any]:
def select_autoescape(
enabled_extensions: t.Collection[str] = ("html", "htm", "xml"),
disabled_extensions: t.Collection[str] = (),
special_extensions: t.Optional[t.Dict[str, t.Callable]] = None,
special_extensions: t.Optional[t.Dict[str, EscapeFunc]] = None,
default_for_string: bool = True,
default: bool = False,
) -> t.Callable[[t.Optional[str]], bool]:
) -> t.Callable[[t.Optional[str]], t.Union[bool, EscapeFunc]]:
"""Intelligently sets the initial value of autoescaping based on the
filename of the template. This is the recommended way to configure
autoescaping if you do not want to write a custom function yourself.
Expand Down Expand Up @@ -716,7 +718,7 @@ def select_autoescape(
parameter ``special_extensions`` was added
"""

def extension_str(x):
def extension_str(x: str) -> str:
"""return a lower case extension always starting with point"""
return f".{x.lstrip('.').lower()}"

Expand All @@ -725,22 +727,23 @@ def extension_str(x):

if special_extensions is None:
special_extensions = {}
if special_extensions is False:
special_extensions = {}
special_extensions = {
extension_str(key): func for key, func in special_extensions.items()
}

def autoescape(template_name: t.Optional[str]) -> bool:
def autoescape(template_name: t.Optional[str]) -> t.Union[bool, EscapeFunc]:
if template_name is None:
return default_for_string
template_name = template_name.lower()
# Lookup autoescape function using the longest matching suffix

for key, func in sorted(
special_extensions.items(), key=lambda x: len(x[0]), reverse=True
special_extensions.items(), # type: ignore
key=lambda x: len(x[0]),
reverse=True,
):
if template_name.endswith(key):
return func
return t.cast(EscapeFunc, func)
if template_name.endswith(enabled_patterns):
return True
if template_name.endswith(disabled_patterns):
Expand Down Expand Up @@ -826,12 +829,12 @@ class MarkupWrapper(markupsafe.Markup):
"""

@classmethod
def get_unwrapped_escape(cls):
def get_unwrapped_escape(cls) -> t.Callable[[Any], str]:
# Needed for test
return custom_escape

@classmethod
def escape(cls, s):
def escape(cls, s: Any) -> markupsafe.Markup:
"""
Make sure the custom escape function does not escape
already escaped strings
Expand Down

0 comments on commit 7b7f4ef

Please sign in to comment.