Skip to content
"""Parse tokens from the lexer into nodes for the compiler."""
import typing
import typing as t
from . import nodes
from .exceptions import TemplateAssertionError
from .exceptions import TemplateSyntaxError
from .lexer import describe_token
from .lexer import describe_token_expr
if t.TYPE_CHECKING:
import typing_extensions as te
from .environment import Environment
_ImportInclude = t.TypeVar("_ImportInclude", nodes.Import, nodes.Include)
_MacroCall = t.TypeVar("_MacroCall", nodes.Macro, nodes.CallBlock)
_statement_keywords = frozenset(
[
"for",
"if",
"block",
"extends",
"print",
"macro",
"include",
"from",
"import",
"set",
"with",
"autoescape",
]
)
_compare_operators = frozenset(["eq", "ne", "lt", "lteq", "gt", "gteq"])
_math_nodes: t.Dict[str, t.Type[nodes.Expr]] = {
"add": nodes.Add,
"sub": nodes.Sub,
"mul": nodes.Mul,
"div": nodes.Div,
"floordiv": nodes.FloorDiv,
"mod": nodes.Mod,
}
class Parser:
"""This is the central parsing class Jinja uses. It's passed to
extensions and can be used to parse expressions or statements.
"""
def __init__(
self,
environment: "Environment",
source: str,
name: t.Optional[str] = None,
filename: t.Optional[str] = None,
state: t.Optional[str] = None,
) -> None:
self.environment = environment
self.stream = environment._tokenize(source, name, filename, state)
self.name = name
self.filename = filename
self.closed = False
self.extensions: t.Dict[
str, t.Callable[["Parser"], t.Union[nodes.Node, t.List[nodes.Node]]]
] = {}
for extension in environment.iter_extensions():
for tag in extension.tags:
self.extensions[tag] = extension.parse
self._last_identifier = 0
self._tag_stack: t.List[str] = []
self._end_token_stack: t.List[t.Tuple[str, ...]] = []
def fail(
self,
msg: str,
lineno: t.Optional[int] = None,
exc: t.Type[TemplateSyntaxError] = TemplateSyntaxError,
) -> "te.NoReturn":
"""Convenience method that raises `exc` with the message, passed
line number or last line number as well as the current name and
filename.
"""
if lineno is None:
lineno = self.stream.current.lineno
raise exc(msg, lineno, self.name, self.filename)
def _fail_ut_eof(
self,
name: t.Optional[str],
end_token_stack: t.List[t.Tuple[str, ...]],
lineno: t.Optional[int],
) -> "te.NoReturn":
expected: t.Set[str] = set()
for exprs in end_token_stack:
expected.update(map(describe_token_expr, exprs))
if end_token_stack:
currently_looking: t.Optional[str] = " or ".join(
map(repr, map(describe_token_expr, end_token_stack[-1]))
)
else:
currently_looking = None
if name is None:
message = ["Unexpected end of template."]
else:
message = [f"Encountered unknown tag {name!r}."]
if currently_looking:
if name is not None and name in expected:
message.append(
"You probably made a nesting mistake. Jinja is expecting this tag,"
f" but currently looking for {currently_looking}."
)
else:
message.append(
f"Jinja was looking for the following tags: {currently_looking}."
)
if self._tag_stack:
message.append(
"The innermost block that needs to be closed is"
f" {self._tag_stack[-1]!r}."
)
self.fail(" ".join(message), lineno)
def fail_unknown_tag(
self, name: str, lineno: t.Optional[int] = None
) -> "te.NoReturn":
"""Called if the parser encounters an unknown tag. Tries to fail
with a human readable error message that could help to identify
the problem.
"""
self._fail_ut_eof(name, self._end_token_stack, lineno)
def fail_eof(
self,
end_tokens: t.Optional[t.Tuple[str, ...]] = None,
lineno: t.Optional[int] = None,
) -> "te.NoReturn":
"""Like fail_unknown_tag but for end of template situations."""
stack = list(self._end_token_stack)
if end_tokens is not None:
stack.append(end_tokens)
self._fail_ut_eof(None, stack, lineno)
def is_tuple_end(
self, extra_end_rules: t.Optional[t.Tuple[str, ...]] = None
) -> bool:
"""Are we at the end of a tuple?"""
if self.stream.current.type in ("variable_end", "block_end", "rparen"):
return True
elif extra_end_rules is not None:
return self.stream.current.test_any(extra_end_rules) # type: ignore
return False
def free_identifier(self, lineno: t.Optional[int] = None) -> nodes.InternalName:
"""Return a new free identifier as :class:`~jinja2.nodes.InternalName`."""
self._last_identifier += 1
rv = object.__new__(nodes.InternalName)
nodes.Node.__init__(rv, f"fi{self._last_identifier}", lineno=lineno)
return rv
def parse_statement(self) -> t.Union[nodes.Node, t.List[nodes.Node]]:
"""Parse a single statement."""
token = self.stream.current
if token.type != "name":
self.fail("tag name expected", token.lineno)
self._tag_stack.append(token.value)
pop_tag = True
try:
if token.value in _statement_keywords:
f = getattr(self, f"parse_{self.stream.current.value}")
return f() # type: ignore
if token.value == "call":
return self.parse_call_block()
if token.value == "filter":
return self.parse_filter_block()
ext = self.extensions.get(token.value)
if ext is not None:
return ext(self)
# did not work out, remove the token we pushed by accident
# from the stack so that the unknown tag fail function can
# produce a proper error message.
self._tag_stack.pop()
pop_tag = False
self.fail_unknown_tag(token.value, token.lineno)
finally:
if pop_tag:
self._tag_stack.pop()
def parse_statements(
self, end_tokens: t.Tuple[str, ...], drop_needle: bool = False
) -> t.List[nodes.Node]:
"""Parse multiple statements into a list until one of the end tokens
is reached. This is used to parse the body of statements as it also
parses template data if appropriate. The parser checks first if the
current token is a colon and skips it if there is one. Then it checks
for the block end and parses until if one of the `end_tokens` is
reached. Per default the active token in the stream at the end of
the call is the matched end token. If this is not wanted `drop_needle`
can be set to `True` and the end token is removed.
"""
# the first token may be a colon for python compatibility
self.stream.skip_if("colon")
# in the future it would be possible to add whole code sections
# by adding some sort of end of statement token and parsing those here.
self.stream.expect("block_end")
result = self.subparse(end_tokens)
# we reached the end of the template too early, the subparser
# does not check for this, so we do that now
if self.stream.current.type == "eof":
self.fail_eof(end_tokens)
if drop_needle:
next(self.stream)
return result
def parse_set(self) -> t.Union[nodes.Assign, nodes.AssignBlock]:
"""Parse an assign statement."""
lineno = next(self.stream).lineno
target = self.parse_assign_target(with_namespace=True)
if self.stream.skip_if("assign"):
expr = self.parse_tuple()
return nodes.Assign(target, expr, lineno=lineno)
filter_node = self.parse_filter(None)
body = self.parse_statements(("name:endset",), drop_needle=True)
return nodes.AssignBlock(target, filter_node, body, lineno=lineno)
def parse_for(self) -> nodes.For:
"""Parse a for loop."""
lineno = self.stream.expect("name:for").lineno
target = self.parse_assign_target(extra_end_rules=("name:in",))
self.stream.expect("name:in")
iter = self.parse_tuple(
with_condexpr=False, extra_end_rules=("name:recursive",)
)
test = None
if self.stream.skip_if("name:if"):
test = self.parse_expression()
recursive = self.stream.skip_if("name:recursive")
body = self.parse_statements(("name:endfor", "name:else"))
if next(self.stream).value == "endfor":
else_ = []
else:
else_ = self.parse_statements(("name:endfor",), drop_needle=True)
return nodes.For(target, iter, body, else_, test, recursive, lineno=lineno)
def parse_if(self) -> nodes.If:
"""Parse an if construct."""
node = result = nodes.If(lineno=self.stream.expect("name:if").lineno)
while True:
node.test = self.parse_tuple(with_condexpr=False)
node.body = self.parse_statements(("name:elif", "name:else", "name:endif"))
node.elif_ = []
node.else_ = []
token = next(self.stream)
if token.test("name:elif"):
node = nodes.If(lineno=self.stream.current.lineno)
result.elif_.append(node)
continue
elif token.test("name:else"):
result.else_ = self.parse_statements(("name:endif",), drop_needle=True)
break
return result
def parse_with(self) -> nodes.With:
node = nodes.With(lineno=next(self.stream).lineno)
targets: t.List[nodes.Expr] = []
values: t.List[nodes.Expr] = []
while self.stream.current.type != "block_end":
if targets:
self.stream.expect("comma")
target = self.parse_assign_target()
target.set_ctx("param")
targets.append(target)
self.stream.expect("assign")
values.append(self.parse_expression())
node.targets = targets
node.values = values
node.body = self.parse_statements(("name:endwith",), drop_needle=True)
return node
def parse_autoescape(self) -> nodes.Scope:
node = nodes.ScopedEvalContextModifier(lineno=next(self.stream).lineno)
node.options = [nodes.Keyword("autoescape", self.parse_expression())]
node.body = self.parse_statements(("name:endautoescape",), drop_needle=True)
return nodes.Scope([node])
def parse_block(self) -> nodes.Block:
node = nodes.Block(lineno=next(self.stream).lineno)
node.name = self.stream.expect("name").value
node.scoped = self.stream.skip_if("name:scoped")
node.required = self.stream.skip_if("name:required")
# common problem people encounter when switching from django
# to jinja. we do not support hyphens in block names, so let's
# raise a nicer error message in that case.
if self.stream.current.type == "sub":
self.fail(
"Block names in Jinja have to be valid Python identifiers and may not"
" contain hyphens, use an underscore instead."
)
node.body = self.parse_statements(("name:endblock",), drop_needle=True)
# enforce that required blocks only contain whitespace or comments
# by asserting that the body, if not empty, is just TemplateData nodes
# with whitespace data
if node.required:
for body_node in node.body:
if not isinstance(body_node, nodes.Output) or any(
not isinstance(output_node, nodes.TemplateData)
or not output_node.data.isspace()
for output_node in body_node.nodes
):
self.fail("Required blocks can only contain comments or whitespace")
self.stream.skip_if("name:" + node.name)
return node
def parse_extends(self) -> nodes.Extends:
node = nodes.Extends(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
return node
def parse_import_context(
self, node: _ImportInclude, default: bool
) -> _ImportInclude:
if self.stream.current.test_any(
"name:with", "name:without"
) and self.stream.look().test("name:context"):
node.with_context = next(self.stream).value == "with"
self.stream.skip()
else:
node.with_context = default
return node
def parse_include(self) -> nodes.Include:
node = nodes.Include(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
if self.stream.current.test("name:ignore") and self.stream.look().test(
"name:missing"
):
node.ignore_missing = True
self.stream.skip(2)
else:
node.ignore_missing = False
return self.parse_import_context(node, True)
def parse_import(self) -> nodes.Import:
node = nodes.Import(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
self.stream.expect("name:as")
node.target = self.parse_assign_target(name_only=True).name
return self.parse_import_context(node, False)
def parse_from(self) -> nodes.FromImport:
node = nodes.FromImport(lineno=next(self.stream).lineno)
node.template = self.parse_expression()
self.stream.expect("name:import")
node.names = []
def parse_context() -> bool:
if self.stream.current.value in {
"with",
"without",
} and self.stream.look().test("name:context"):
node.with_context = next(self.stream).value == "with"
self.stream.skip()
return True
return False
while True:
if node.names:
self.stream.expect("comma")
if self.stream.current.type == "name":
if parse_context():
break
target = self.parse_assign_target(name_only=True)
if target.name.startswith("_"):
self.fail(
"names starting with an underline can not be imported",
target.lineno,
exc=TemplateAssertionError,
)
if self.stream.skip_if("name:as"):
alias = self.parse_assign_target(name_only=True)
node.names.append((target.name, alias.name))
else:
node.names.append(target.name)
if parse_context() or self.stream.current.type != "comma":
break
else:
self.stream.expect("name")
if not hasattr(node, "with_context"):
node.with_context = False
return node
def parse_signature(self, node: _MacroCall) -> None:
args = node.args = []
defaults = node.defaults = []
self.stream.expect("lparen")
while self.stream.current.type != "rparen":
if args:
self.stream.expect("comma")
arg = self.parse_assign_target(name_only=True)
arg.set_ctx("param")
if self.stream.skip_if("assign"):
defaults.append(self.parse_expression())
elif defaults:
self.fail("non-default argument follows default argument")
args.append(arg)
self.stream.expect("rparen")
def parse_call_block(self) -> nodes.CallBlock:
node = nodes.CallBlock(lineno=next(self.stream).lineno)
if self.stream.current.type == "lparen":
self.parse_signature(node)
else:
node.args = []
node.defaults = []
call_node = self.parse_expression()
if not isinstance(call_node, nodes.Call):
self.fail("expected call", node.lineno)
node.call = call_node
node.body = self.parse_statements(("name:endcall",), drop_needle=True)
return node
def parse_filter_block(self) -> nodes.FilterBlock:
node = nodes.FilterBlock(lineno=next(self.stream).lineno)
node.filter = self.parse_filter(None, start_inline=True) # type: ignore
node.body = self.parse_statements(("name:endfilter",), drop_needle=True)
return node
def parse_macro(self) -> nodes.Macro:
node = nodes.Macro(lineno=next(self.stream).lineno)
node.name = self.parse_assign_target(name_only=True).name
self.parse_signature(node)
node.body = self.parse_statements(("name:endmacro",), drop_needle=True)
return node
def parse_print(self) -> nodes.Output:
node = nodes.Output(lineno=next(self.stream).lineno)
node.nodes = []
while self.stream.current.type != "block_end":
if node.nodes:
self.stream.expect("comma")
node.nodes.append(self.parse_expression())
return node
@typing.overload
def parse_assign_target(
self, with_tuple: bool = ..., name_only: "te.Literal[True]" = ...
) -> nodes.Name:
...
@typing.overload
def parse_assign_target(
self,
with_tuple: bool = True,
name_only: bool = False,
extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
with_namespace: bool = False,
) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]:
...
def parse_assign_target(
self,
with_tuple: bool = True,
name_only: bool = False,
extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
with_namespace: bool = False,
) -> t.Union[nodes.NSRef, nodes.Name, nodes.Tuple]:
"""Parse an assignment target. As Jinja allows assignments to
tuples, this function can parse all allowed assignment targets. Per
default assignments to tuples are parsed, that can be disable however
by setting `with_tuple` to `False`. If only assignments to names are
wanted `name_only` can be set to `True`. The `extra_end_rules`
parameter is forwarded to the tuple parsing function. If
`with_namespace` is enabled, a namespace assignment may be parsed.
"""
target: nodes.Expr
if with_namespace and self.stream.look().type == "dot":
token = self.stream.expect("name")
next(self.stream) # dot
attr = self.stream.expect("name")
target = nodes.NSRef(token.value, attr.value, lineno=token.lineno)
elif name_only:
token = self.stream.expect("name")
target = nodes.Name(token.value, "store", lineno=token.lineno)
else:
if with_tuple:
target = self.parse_tuple(
simplified=True, extra_end_rules=extra_end_rules
)
else:
target = self.parse_primary()
target.set_ctx("store")
if not target.can_assign():
self.fail(
f"can't assign to {type(target).__name__.lower()!r}", target.lineno
)
return target # type: ignore
def parse_expression(self, with_condexpr: bool = True) -> nodes.Expr:
"""Parse an expression. Per default all expressions are parsed, if
the optional `with_condexpr` parameter is set to `False` conditional
expressions are not parsed.
"""
if with_condexpr:
return self.parse_condexpr()
return self.parse_or()
def parse_condexpr(self) -> nodes.Expr:
lineno = self.stream.current.lineno
expr1 = self.parse_or()
expr3: t.Optional[nodes.Expr]
while self.stream.skip_if("name:if"):
expr2 = self.parse_or()
if self.stream.skip_if("name:else"):
expr3 = self.parse_condexpr()
else:
expr3 = None
expr1 = nodes.CondExpr(expr2, expr1, expr3, lineno=lineno)
lineno = self.stream.current.lineno
return expr1
def parse_or(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_and()
while self.stream.skip_if("name:or"):
right = self.parse_and()
left = nodes.Or(left, right, lineno=lineno)
lineno = self.stream.current.lineno
return left
def parse_and(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_not()
while self.stream.skip_if("name:and"):
right = self.parse_not()
left = nodes.And(left, right, lineno=lineno)
lineno = self.stream.current.lineno
return left
def parse_not(self) -> nodes.Expr:
if self.stream.current.test("name:not"):
lineno = next(self.stream).lineno
return nodes.Not(self.parse_not(), lineno=lineno)
return self.parse_compare()
def parse_compare(self) -> nodes.Expr:
lineno = self.stream.current.lineno
expr = self.parse_math1()
ops = []
while True:
token_type = self.stream.current.type
if token_type in _compare_operators:
next(self.stream)
ops.append(nodes.Operand(token_type, self.parse_math1()))
elif self.stream.skip_if("name:in"):
ops.append(nodes.Operand("in", self.parse_math1()))
elif self.stream.current.test("name:not") and self.stream.look().test(
"name:in"
):
self.stream.skip(2)
ops.append(nodes.Operand("notin", self.parse_math1()))
else:
break
lineno = self.stream.current.lineno
if not ops:
return expr
return nodes.Compare(expr, ops, lineno=lineno)
def parse_math1(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_concat()
while self.stream.current.type in ("add", "sub"):
cls = _math_nodes[self.stream.current.type]
next(self.stream)
right = self.parse_concat()
left = cls(left, right, lineno=lineno)
lineno = self.stream.current.lineno
return left
def parse_concat(self) -> nodes.Expr:
lineno = self.stream.current.lineno
args = [self.parse_math2()]
while self.stream.current.type == "tilde":
next(self.stream)
args.append(self.parse_math2())
if len(args) == 1:
return args[0]
return nodes.Concat(args, lineno=lineno)
def parse_math2(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_pow()
while self.stream.current.type in ("mul", "div", "floordiv", "mod"):
cls = _math_nodes[self.stream.current.type]
next(self.stream)
right = self.parse_pow()
left = cls(left, right, lineno=lineno)
lineno = self.stream.current.lineno
return left
def parse_pow(self) -> nodes.Expr:
lineno = self.stream.current.lineno
left = self.parse_unary()
while self.stream.current.type == "pow":
next(self.stream)
right = self.parse_unary()
left = nodes.Pow(left, right, lineno=lineno)
lineno = self.stream.current.lineno
return left
def parse_unary(self, with_filter: bool = True) -> nodes.Expr:
token_type = self.stream.current.type
lineno = self.stream.current.lineno
node: nodes.Expr
if token_type == "sub":
next(self.stream)
node = nodes.Neg(self.parse_unary(False), lineno=lineno)
elif token_type == "add":
next(self.stream)
node = nodes.Pos(self.parse_unary(False), lineno=lineno)
else:
node = self.parse_primary()
node = self.parse_postfix(node)
if with_filter:
node = self.parse_filter_expr(node)
return node
def parse_primary(self) -> nodes.Expr:
token = self.stream.current
node: nodes.Expr
if token.type == "name":
if token.value in ("true", "false", "True", "False"):
node = nodes.Const(token.value in ("true", "True"), lineno=token.lineno)
elif token.value in ("none", "None"):
node = nodes.Const(None, lineno=token.lineno)
else:
node = nodes.Name(token.value, "load", lineno=token.lineno)
next(self.stream)
elif token.type == "string":
next(self.stream)
buf = [token.value]
lineno = token.lineno
while self.stream.current.type == "string":
buf.append(self.stream.current.value)
next(self.stream)
node = nodes.Const("".join(buf), lineno=lineno)
elif token.type in ("integer", "float"):
next(self.stream)
node = nodes.Const(token.value, lineno=token.lineno)
elif token.type == "lparen":
next(self.stream)
node = self.parse_tuple(explicit_parentheses=True)
self.stream.expect("rparen")
elif token.type == "lbracket":
node = self.parse_list()
elif token.type == "lbrace":
node = self.parse_dict()
else:
self.fail(f"unexpected {describe_token(token)!r}", token.lineno)
return node
def parse_tuple(
self,
simplified: bool = False,
with_condexpr: bool = True,
extra_end_rules: t.Optional[t.Tuple[str, ...]] = None,
explicit_parentheses: bool = False,
) -> t.Union[nodes.Tuple, nodes.Expr]:
"""Works like `parse_expression` but if multiple expressions are
delimited by a comma a :class:`~jinja2.nodes.Tuple` node is created.
This method could also return a regular expression instead of a tuple
if no commas where found.
The default parsing mode is a full tuple. If `simplified` is `True`
only names and literals are parsed. The `no_condexpr` parameter is
forwarded to :meth:`parse_expression`.
Because tuples do not require delimiters and may end in a bogus comma
an extra hint is needed that marks the end of a tuple. For example
for loops support tuples between `for` and `in`. In that case the
`extra_end_rules` is set to ``['name:in']``.
`explicit_parentheses` is true if the parsing was triggered by an
expression in parentheses. This is used to figure out if an empty
tuple is a valid expression or not.
"""
lineno = self.stream.current.lineno
if simplified:
parse = self.parse_primary
elif with_condexpr:
parse = self.parse_expression
else:
def parse() -> nodes.Expr:
return self.parse_expression(with_condexpr=False)
args: t.List[nodes.Expr] = []
is_tuple = False
while True:
if args:
self.stream.expect("comma")
if self.is_tuple_end(extra_end_rules):
break
args.append(parse())
if self.stream.current.type == "comma":
is_tuple = True
else:
break
lineno = self.stream.current.lineno
if not is_tuple:
if args:
return args[0]
# if we don't have explicit parentheses, an empty tuple is
# not a valid expression. This would mean nothing (literally
# nothing) in the spot of an expression would be an empty
# tuple.
if not explicit_parentheses:
self.fail(
"Expected an expression,"
f" got {describe_token(self.stream.current)!r}"
)
return nodes.Tuple(args, "load", lineno=lineno)
def parse_list(self) -> nodes.List:
token = self.stream.expect("lbracket")
items: t.List[nodes.Expr] = []
while self.stream.current.type != "rbracket":
if items:
self.stream.expect("comma")
if self.stream.current.type == "rbracket":
break
items.append(self.parse_expression())
self.stream.expect("rbracket")
return nodes.List(items, lineno=token.lineno)
def parse_dict(self) -> nodes.Dict:
token = self.stream.expect("lbrace")
items: t.List[nodes.Pair] = []
while self.stream.current.type != "rbrace":
if items:
self.stream.expect("comma")
if self.stream.current.type == "rbrace":
break
key = self.parse_expression()
self.stream.expect("colon")
value = self.parse_expression()
items.append(nodes.Pair(key, value, lineno=key.lineno))
self.stream.expect("rbrace")
return nodes.Dict(items, lineno=token.lineno)
def parse_postfix(self, node: nodes.Expr) -> nodes.Expr:
while True:
token_type = self.stream.current.type
if token_type == "dot" or token_type == "lbracket":
node = self.parse_subscript(node)
# calls are valid both after postfix expressions (getattr
# and getitem) as well as filters and tests
elif token_type == "lparen":
node = self.parse_call(node)
else:
break
return node
def parse_filter_expr(self, node: nodes.Expr) -> nodes.Expr:
while True:
token_type = self.stream.current.type
if token_type == "pipe":
node = self.parse_filter(node) # type: ignore
elif token_type == "name" and self.stream.current.value == "is":
node = self.parse_test(node)
# calls are valid both after postfix expressions (getattr
# and getitem) as well as filters and tests
elif token_type == "lparen":
node = self.parse_call(node)
else:
break
return node
def parse_subscript(
self, node: nodes.Expr
) -> t.Union[nodes.Getattr, nodes.Getitem]:
token = next(self.stream)
arg: nodes.Expr
if token.type == "dot":
attr_token = self.stream.current
next(self.stream)
if attr_token.type == "name":
return nodes.Getattr(
node, attr_token.value, "load", lineno=token.lineno
)
elif attr_token.type != "integer":
self.fail("expected name or number", attr_token.lineno)
arg = nodes.Const(attr_token.value, lineno=attr_token.lineno)
return nodes.Getitem(node, arg, "load", lineno=token.lineno)
if token.type == "lbracket":
args: t.List[nodes.Expr] = []
while self.stream.current.type != "rbracket":
if args:
self.stream.expect("comma")
args.append(self.parse_subscribed())
self.stream.expect("rbracket")
if len(args) == 1:
arg = args[0]
else:
arg = nodes.Tuple(args, "load", lineno=token.lineno)
return nodes.Getitem(node, arg, "load", lineno=token.lineno)
self.fail("expected subscript expression", token.lineno)
def parse_subscribed(self) -> nodes.Expr:
lineno = self.stream.current.lineno
args: t.List[t.Optional[nodes.Expr]]
if self.stream.current.type == "colon":
next(self.stream)
args = [None]
else:
node = self.parse_expression()
if self.stream.current.type != "colon":
return node
next(self.stream)
args = [node]
if self.stream.current.type == "colon":
args.append(None)
elif self.stream.current.type not in ("rbracket", "comma"):
args.append(self.parse_expression())
else:
args.append(None)
if self.stream.current.type == "colon":
next(self.stream)
if self.stream.current.type not in ("rbracket", "comma"):
args.append(self.parse_expression())
else:
args.append(None)
else:
args.append(None)
return nodes.Slice(lineno=lineno, *args) # noqa: B026
def parse_call_args(self) -> t.Tuple:
token = self.stream.expect("lparen")
args = []
kwargs = []
dyn_args = None
dyn_kwargs = None
require_comma = False
def ensure(expr: bool) -> None:
if not expr:
self.fail("invalid syntax for function call expression", token.lineno)
while self.stream.current.type != "rparen":
if require_comma:
self.stream.expect("comma")
# support for trailing comma
if self.stream.current.type == "rparen":
break
if self.stream.current.type == "mul":
ensure(dyn_args is None and dyn_kwargs is None)
next(self.stream)
dyn_args = self.parse_expression()
elif self.stream.current.type == "pow":
ensure(dyn_kwargs is None)
next(self.stream)
dyn_kwargs = self.parse_expression()
else:
if (
self.stream.current.type == "name"
and self.stream.look().type == "assign"
):
# Parsing a kwarg
ensure(dyn_kwargs is None)
key = self.stream.current.value
self.stream.skip(2)
value = self.parse_expression()
kwargs.append(nodes.Keyword(key, value, lineno=value.lineno))
else:
# Parsing an arg
ensure(dyn_args is None and dyn_kwargs is None and not kwargs)
args.append(self.parse_expression())
require_comma = True
self.stream.expect("rparen")
return args, kwargs, dyn_args, dyn_kwargs
def parse_call(self, node: nodes.Expr) -> nodes.Call:
# The lparen will be expected in parse_call_args, but the lineno
# needs to be recorded before the stream is advanced.
token = self.stream.current
args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
return nodes.Call(node, args, kwargs, dyn_args, dyn_kwargs, lineno=token.lineno)
def parse_filter(
self, node: t.Optional[nodes.Expr], start_inline: bool = False
) -> t.Optional[nodes.Expr]:
while self.stream.current.type == "pipe" or start_inline:
if not start_inline:
next(self.stream)
token = self.stream.expect("name")
name = token.value
while self.stream.current.type == "dot":
next(self.stream)
name += "." + self.stream.expect("name").value
if self.stream.current.type == "lparen":
args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
else:
args = []
kwargs = []
dyn_args = dyn_kwargs = None
node = nodes.Filter(
node, name, args, kwargs, dyn_args, dyn_kwargs, lineno=token.lineno
)
start_inline = False
return node
def parse_test(self, node: nodes.Expr) -> nodes.Expr:
token = next(self.stream)
if self.stream.current.test("name:not"):
next(self.stream)
negated = True
else:
negated = False
name = self.stream.expect("name").value
while self.stream.current.type == "dot":
next(self.stream)
name += "." + self.stream.expect("name").value
dyn_args = dyn_kwargs = None
kwargs = []
if self.stream.current.type == "lparen":
args, kwargs, dyn_args, dyn_kwargs = self.parse_call_args()
elif self.stream.current.type in {
"name",
"string",
"integer",
"float",
"lparen",
"lbracket",
"lbrace",
} and not self.stream.current.test_any("name:else", "name:or", "name:and"):
if self.stream.current.test("name:is"):
self.fail("You cannot chain multiple tests with is")
arg_node = self.parse_primary()
arg_node = self.parse_postfix(arg_node)
args = [arg_node]
else:
args = []
node = nodes.Test(
node, name, args, kwargs, dyn_args, dyn_kwargs, lineno=token.lineno
)
if negated:
node = nodes.Not(node, lineno=token.lineno)
return node
def subparse(
self, end_tokens: t.Optional[t.Tuple[str, ...]] = None
) -> t.List[nodes.Node]:
body: t.List[nodes.Node] = []
data_buffer: t.List[nodes.Node] = []
add_data = data_buffer.append
if end_tokens is not None:
self._end_token_stack.append(end_tokens)
def flush_data() -> None:
if data_buffer:
lineno = data_buffer[0].lineno
body.append(nodes.Output(data_buffer[:], lineno=lineno))
del data_buffer[:]
try:
while self.stream:
token = self.stream.current
if token.type == "data":
if token.value:
add_data(nodes.TemplateData(token.value, lineno=token.lineno))
next(self.stream)
elif token.type == "variable_begin":
next(self.stream)
add_data(self.parse_tuple(with_condexpr=True))
self.stream.expect("variable_end")
elif token.type == "block_begin":
flush_data()
next(self.stream)
if end_tokens is not None and self.stream.current.test_any(
*end_tokens
):
return body
rv = self.parse_statement()
if isinstance(rv, list):
body.extend(rv)
else:
body.append(rv)
self.stream.expect("block_end")
else:
raise AssertionError("internal parsing error")
flush_data()
finally:
if end_tokens is not None:
self._end_token_stack.pop()
return body
def parse(self) -> nodes.Template:
"""Parse the whole template into a `Template` node."""
result = nodes.Template(self.subparse(), lineno=1)
result.set_environment(self.environment)
return result
"""The runtime functions and state used by compiled templates."""
import functools
import sys
import typing as t
from collections import abc
from itertools import chain
from markupsafe import escape # noqa: F401
from markupsafe import Markup
from markupsafe import soft_str
from .async_utils import auto_aiter
from .async_utils import auto_await # noqa: F401
from .exceptions import TemplateNotFound # noqa: F401
from .exceptions import TemplateRuntimeError # noqa: F401
from .exceptions import UndefinedError
from .nodes import EvalContext
from .utils import _PassArg
from .utils import concat
from .utils import internalcode
from .utils import missing
from .utils import Namespace # noqa: F401
from .utils import object_type_repr
from .utils import pass_eval_context
V = t.TypeVar("V")
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
if t.TYPE_CHECKING:
import logging
import typing_extensions as te
from .environment import Environment
class LoopRenderFunc(te.Protocol):
def __call__(
self,
reciter: t.Iterable[V],
loop_render_func: "LoopRenderFunc",
depth: int = 0,
) -> str:
...
# these variables are exported to the template runtime
exported = [
"LoopContext",
"TemplateReference",
"Macro",
"Markup",
"TemplateRuntimeError",
"missing",
"escape",
"markup_join",
"str_join",
"identity",
"TemplateNotFound",
"Namespace",
"Undefined",
"internalcode",
]
async_exported = [
"AsyncLoopContext",
"auto_aiter",
"auto_await",
]
def identity(x: V) -> V:
"""Returns its argument. Useful for certain things in the
environment.
"""
return x
def markup_join(seq: t.Iterable[t.Any]) -> str:
"""Concatenation that escapes if necessary and converts to string."""
buf = []
iterator = map(soft_str, seq)
for arg in iterator:
buf.append(arg)
if hasattr(arg, "__html__"):
return Markup("").join(chain(buf, iterator))
return concat(buf)
def str_join(seq: t.Iterable[t.Any]) -> str:
"""Simple args to string conversion and concatenation."""
return concat(map(str, seq))
def new_context(
environment: "Environment",
template_name: t.Optional[str],
blocks: t.Dict[str, t.Callable[["Context"], t.Iterator[str]]],
vars: t.Optional[t.Dict[str, t.Any]] = None,
shared: bool = False,
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
locals: t.Optional[t.Mapping[str, t.Any]] = None,
) -> "Context":
"""Internal helper for context creation."""
if vars is None:
vars = {}
if shared:
parent = vars
else:
parent = dict(globals or (), **vars)
if locals:
# if the parent is shared a copy should be created because
# we don't want to modify the dict passed
if shared:
parent = dict(parent)
for key, value in locals.items():
if value is not missing:
parent[key] = value
return environment.context_class(
environment, parent, template_name, blocks, globals=globals
)
class TemplateReference:
"""The `self` in templates."""
def __init__(self, context: "Context") -> None:
self.__context = context
def __getitem__(self, name: str) -> t.Any:
blocks = self.__context.blocks[name]
return BlockReference(name, self.__context, blocks, 0)
def __repr__(self) -> str:
return f"<{type(self).__name__} {self.__context.name!r}>"
def _dict_method_all(dict_method: F) -> F:
@functools.wraps(dict_method)
def f_all(self: "Context") -> t.Any:
return dict_method(self.get_all())
return t.cast(F, f_all)
@abc.Mapping.register
class Context:
"""The template context holds the variables of a template. It stores the
values passed to the template and also the names the template exports.
Creating instances is neither supported nor useful as it's created
automatically at various stages of the template evaluation and should not
be created by hand.
The context is immutable. Modifications on :attr:`parent` **must not**
happen and modifications on :attr:`vars` are allowed from generated
template code only. Template filters and global functions marked as
:func:`pass_context` get the active context passed as first argument
and are allowed to access the context read-only.
The template context supports read only dict operations (`get`,
`keys`, `values`, `items`, `iterkeys`, `itervalues`, `iteritems`,
`__getitem__`, `__contains__`). Additionally there is a :meth:`resolve`
method that doesn't fail with a `KeyError` but returns an
:class:`Undefined` object for missing variables.
"""
def __init__(
self,
environment: "Environment",
parent: t.Dict[str, t.Any],
name: t.Optional[str],
blocks: t.Dict[str, t.Callable[["Context"], t.Iterator[str]]],
globals: t.Optional[t.MutableMapping[str, t.Any]] = None,
):
self.parent = parent
self.vars: t.Dict[str, t.Any] = {}
self.environment: "Environment" = environment
self.eval_ctx = EvalContext(self.environment, name)
self.exported_vars: t.Set[str] = set()
self.name = name
self.globals_keys = set() if globals is None else set(globals)
# create the initial mapping of blocks. Whenever template inheritance
# takes place the runtime will update this mapping with the new blocks
# from the template.
self.blocks = {k: [v] for k, v in blocks.items()}
def super(
self, name: str, current: t.Callable[["Context"], t.Iterator[str]]
) -> t.Union["BlockReference", "Undefined"]:
"""Render a parent block."""
try:
blocks = self.blocks[name]
index = blocks.index(current) + 1
blocks[index]
except LookupError:
return self.environment.undefined(
f"there is no parent block called {name!r}.", name="super"
)
return BlockReference(name, self, blocks, index)
def get(self, key: str, default: t.Any = None) -> t.Any:
"""Look up a variable by name, or return a default if the key is
not found.
:param key: The variable name to look up.
:param default: The value to return if the key is not found.
"""
try:
return self[key]
except KeyError:
return default
def resolve(self, key: str) -> t.Union[t.Any, "Undefined"]:
"""Look up a variable by name, or return an :class:`Undefined`
object if the key is not found.
If you need to add custom behavior, override
:meth:`resolve_or_missing`, not this method. The various lookup
functions use that method, not this one.
:param key: The variable name to look up.
"""
rv = self.resolve_or_missing(key)
if rv is missing:
return self.environment.undefined(name=key)
return rv
def resolve_or_missing(self, key: str) -> t.Any:
"""Look up a variable by name, or return a ``missing`` sentinel
if the key is not found.
Override this method to add custom lookup behavior.
:meth:`resolve`, :meth:`get`, and :meth:`__getitem__` use this
method. Don't call this method directly.
:param key: The variable name to look up.
"""
if key in self.vars:
return self.vars[key]
if key in self.parent:
return self.parent[key]
return missing
def get_exported(self) -> t.Dict[str, t.Any]:
"""Get a new dict with the exported variables."""
return {k: self.vars[k] for k in self.exported_vars}
def get_all(self) -> t.Dict[str, t.Any]:
"""Return the complete context as dict including the exported
variables. For optimizations reasons this might not return an
actual copy so be careful with using it.
"""
if not self.vars:
return self.parent
if not self.parent:
return self.vars
return dict(self.parent, **self.vars)
@internalcode
def call(
__self, __obj: t.Callable, *args: t.Any, **kwargs: t.Any # noqa: B902
) -> t.Union[t.Any, "Undefined"]:
"""Call the callable with the arguments and keyword arguments
provided but inject the active context or environment as first
argument if the callable has :func:`pass_context` or
:func:`pass_environment`.
"""
if __debug__:
__traceback_hide__ = True # noqa
# Allow callable classes to take a context
if (
hasattr(__obj, "__call__") # noqa: B004
and _PassArg.from_obj(__obj.__call__) is not None
):
__obj = __obj.__call__
pass_arg = _PassArg.from_obj(__obj)
if pass_arg is _PassArg.context:
# the active context should have access to variables set in
# loops and blocks without mutating the context itself
if kwargs.get("_loop_vars"):
__self = __self.derived(kwargs["_loop_vars"])
if kwargs.get("_block_vars"):
__self = __self.derived(kwargs["_block_vars"])
args = (__self,) + args
elif pass_arg is _PassArg.eval_context:
args = (__self.eval_ctx,) + args
elif pass_arg is _PassArg.environment:
args = (__self.environment,) + args
kwargs.pop("_block_vars", None)
kwargs.pop("_loop_vars", None)
try:
return __obj(*args, **kwargs)
except StopIteration:
return __self.environment.undefined(
"value was undefined because a callable raised a"
" StopIteration exception"
)
def derived(self, locals: t.Optional[t.Dict[str, t.Any]] = None) -> "Context":
"""Internal helper function to create a derived context. This is
used in situations where the system needs a new context in the same
template that is independent.
"""
context = new_context(
self.environment, self.name, {}, self.get_all(), True, None, locals
)
context.eval_ctx = self.eval_ctx
context.blocks.update((k, list(v)) for k, v in self.blocks.items())
return context
keys = _dict_method_all(dict.keys)
values = _dict_method_all(dict.values)
items = _dict_method_all(dict.items)
def __contains__(self, name: str) -> bool:
return name in self.vars or name in self.parent
def __getitem__(self, key: str) -> t.Any:
"""Look up a variable by name with ``[]`` syntax, or raise a
``KeyError`` if the key is not found.
"""
item = self.resolve_or_missing(key)
if item is missing:
raise KeyError(key)
return item
def __repr__(self) -> str:
return f"<{type(self).__name__} {self.get_all()!r} of {self.name!r}>"
class BlockReference:
"""One block on a template reference."""
def __init__(
self,
name: str,
context: "Context",
stack: t.List[t.Callable[["Context"], t.Iterator[str]]],
depth: int,
) -> None:
self.name = name
self._context = context
self._stack = stack
self._depth = depth
@property
def super(self) -> t.Union["BlockReference", "Undefined"]:
"""Super the block."""
if self._depth + 1 >= len(self._stack):
return self._context.environment.undefined(
f"there is no parent block called {self.name!r}.", name="super"
)
return BlockReference(self.name, self._context, self._stack, self._depth + 1)
@internalcode
async def _async_call(self) -> str:
rv = concat(
[x async for x in self._stack[self._depth](self._context)] # type: ignore
)
if self._context.eval_ctx.autoescape:
return Markup(rv)
return rv
@internalcode
def __call__(self) -> str:
if self._context.environment.is_async:
return self._async_call() # type: ignore
rv = concat(self._stack[self._depth](self._context))
if self._context.eval_ctx.autoescape:
return Markup(rv)
return rv
class LoopContext:
"""A wrapper iterable for dynamic ``for`` loops, with information
about the loop and iteration.
"""
#: Current iteration of the loop, starting at 0.
index0 = -1
_length: t.Optional[int] = None
_after: t.Any = missing
_current: t.Any = missing
_before: t.Any = missing
_last_changed_value: t.Any = missing
def __init__(
self,
iterable: t.Iterable[V],
undefined: t.Type["Undefined"],
recurse: t.Optional["LoopRenderFunc"] = None,
depth0: int = 0,
) -> None:
"""
:param iterable: Iterable to wrap.
:param undefined: :class:`Undefined` class to use for next and
previous items.
:param recurse: The function to render the loop body when the
loop is marked recursive.
:param depth0: Incremented when looping recursively.
"""
self._iterable = iterable
self._iterator = self._to_iterator(iterable)
self._undefined = undefined
self._recurse = recurse
#: How many levels deep a recursive loop currently is, starting at 0.
self.depth0 = depth0
@staticmethod
def _to_iterator(iterable: t.Iterable[V]) -> t.Iterator[V]:
return iter(iterable)
@property
def length(self) -> int:
"""Length of the iterable.
If the iterable is a generator or otherwise does not have a
size, it is eagerly evaluated to get a size.
"""
if self._length is not None:
return self._length
try:
self._length = len(self._iterable) # type: ignore
except TypeError:
iterable = list(self._iterator)
self._iterator = self._to_iterator(iterable)
self._length = len(iterable) + self.index + (self._after is not missing)
return self._length
def __len__(self) -> int:
return self.length
@property
def depth(self) -> int:
"""How many levels deep a recursive loop currently is, starting at 1."""
return self.depth0 + 1
@property
def index(self) -> int:
"""Current iteration of the loop, starting at 1."""
return self.index0 + 1
@property
def revindex0(self) -> int:
"""Number of iterations from the end of the loop, ending at 0.
Requires calculating :attr:`length`.
"""
return self.length - self.index
@property
def revindex(self) -> int:
"""Number of iterations from the end of the loop, ending at 1.
Requires calculating :attr:`length`.
"""
return self.length - self.index0
@property
def first(self) -> bool:
"""Whether this is the first iteration of the loop."""
return self.index0 == 0
def _peek_next(self) -> t.Any:
"""Return the next element in the iterable, or :data:`missing`
if the iterable is exhausted. Only peeks one item ahead, caching
the result in :attr:`_last` for use in subsequent checks. The
cache is reset when :meth:`__next__` is called.
"""
if self._after is not missing:
return self._after
self._after = next(self._iterator, missing)
return self._after
@property
def last(self) -> bool:
"""Whether this is the last iteration of the loop.
Causes the iterable to advance early. See
:func:`itertools.groupby` for issues this can cause.
The :func:`groupby` filter avoids that issue.
"""
return self._peek_next() is missing
@property
def previtem(self) -> t.Union[t.Any, "Undefined"]:
"""The item in the previous iteration. Undefined during the
first iteration.
"""
if self.first:
return self._undefined("there is no previous item")
return self._before
@property
def nextitem(self) -> t.Union[t.Any, "Undefined"]:
"""The item in the next iteration. Undefined during the last
iteration.
Causes the iterable to advance early. See
:func:`itertools.groupby` for issues this can cause.
The :func:`jinja-filters.groupby` filter avoids that issue.
"""
rv = self._peek_next()
if rv is missing:
return self._undefined("there is no next item")
return rv
def cycle(self, *args: V) -> V:
"""Return a value from the given args, cycling through based on
the current :attr:`index0`.
:param args: One or more values to cycle through.
"""
if not args:
raise TypeError("no items for cycling given")
return args[self.index0 % len(args)]
def changed(self, *value: t.Any) -> bool:
"""Return ``True`` if previously called with a different value
(including when called for the first time).
:param value: One or more values to compare to the last call.
"""
if self._last_changed_value != value:
self._last_changed_value = value
return True
return False
def __iter__(self) -> "LoopContext":
return self
def __next__(self) -> t.Tuple[t.Any, "LoopContext"]:
if self._after is not missing:
rv = self._after
self._after = missing
else:
rv = next(self._iterator)
self.index0 += 1
self._before = self._current
self._current = rv
return rv, self
@internalcode
def __call__(self, iterable: t.Iterable[V]) -> str:
"""When iterating over nested data, render the body of the loop
recursively with the given inner iterable data.
The loop must have the ``recursive`` marker for this to work.
"""
if self._recurse is None:
raise TypeError(
"The loop must have the 'recursive' marker to be called recursively."
)
return self._recurse(iterable, self._recurse, depth=self.depth)
def __repr__(self) -> str:
return f"<{type(self).__name__} {self.index}/{self.length}>"
class AsyncLoopContext(LoopContext):
_iterator: t.AsyncIterator[t.Any] # type: ignore
@staticmethod
def _to_iterator( # type: ignore
iterable: t.Union[t.Iterable[V], t.AsyncIterable[V]]
) -> t.AsyncIterator[V]:
return auto_aiter(iterable)
@property
async def length(self) -> int: # type: ignore
if self._length is not None:
return self._length
try:
self._length = len(self._iterable) # type: ignore
except TypeError:
iterable = [x async for x in self._iterator]
self._iterator = self._to_iterator(iterable)
self._length = len(iterable) + self.index + (self._after is not missing)
return self._length
@property
async def revindex0(self) -> int: # type: ignore
return await self.length - self.index
@property
async def revindex(self) -> int: # type: ignore
return await self.length - self.index0
async def _peek_next(self) -> t.Any:
if self._after is not missing:
return self._after
try:
self._after = await self._iterator.__anext__()
except StopAsyncIteration:
self._after = missing
return self._after
@property
async def last(self) -> bool: # type: ignore
return await self._peek_next() is missing
@property
async def nextitem(self) -> t.Union[t.Any, "Undefined"]:
rv = await self._peek_next()
if rv is missing:
return self._undefined("there is no next item")
return rv
def __aiter__(self) -> "AsyncLoopContext":
return self
async def __anext__(self) -> t.Tuple[t.Any, "AsyncLoopContext"]:
if self._after is not missing:
rv = self._after
self._after = missing
else:
rv = await self._iterator.__anext__()
self.index0 += 1
self._before = self._current
self._current = rv
return rv, self
class Macro:
"""Wraps a macro function."""
def __init__(
self,
environment: "Environment",
func: t.Callable[..., str],
name: str,
arguments: t.List[str],
catch_kwargs: bool,
catch_varargs: bool,
caller: bool,
default_autoescape: t.Optional[bool] = None,
):
self._environment = environment
self._func = func
self._argument_count = len(arguments)
self.name = name
self.arguments = arguments
self.catch_kwargs = catch_kwargs
self.catch_varargs = catch_varargs
self.caller = caller
self.explicit_caller = "caller" in arguments
if default_autoescape is None:
if callable(environment.autoescape):
default_autoescape = environment.autoescape(None)
else:
default_autoescape = environment.autoescape
self._default_autoescape = default_autoescape
@internalcode
@pass_eval_context
def __call__(self, *args: t.Any, **kwargs: t.Any) -> str:
# This requires a bit of explanation, In the past we used to
# decide largely based on compile-time information if a macro is
# safe or unsafe. While there was a volatile mode it was largely
# unused for deciding on escaping. This turns out to be
# problematic for macros because whether a macro is safe depends not
# on the escape mode when it was defined, but rather when it was used.
#
# Because however we export macros from the module system and
# there are historic callers that do not pass an eval context (and
# will continue to not pass one), we need to perform an instance
# check here.
#
# This is considered safe because an eval context is not a valid
# argument to callables otherwise anyway. Worst case here is
# that if no eval context is passed we fall back to the compile
# time autoescape flag.
if args and isinstance(args[0], EvalContext):
autoescape = args[0].autoescape
args = args[1:]
else:
autoescape = self._default_autoescape
# try to consume the positional arguments
arguments = list(args[: self._argument_count])
off = len(arguments)
# For information why this is necessary refer to the handling
# of caller in the `macro_body` handler in the compiler.
found_caller = False
# if the number of arguments consumed is not the number of
# arguments expected we start filling in keyword arguments
# and defaults.
if off != self._argument_count:
for name in self.arguments[len(arguments) :]:
try:
value = kwargs.pop(name)
except KeyError:
value = missing
if name == "caller":
found_caller = True
arguments.append(value)
else:
found_caller = self.explicit_caller
# it's important that the order of these arguments does not change
# if not also changed in the compiler's `function_scoping` method.
# the order is caller, keyword arguments, positional arguments!
if self.caller and not found_caller:
caller = kwargs.pop("caller", None)
if caller is None:
caller = self._environment.undefined("No caller defined", name="caller")
arguments.append(caller)
if self.catch_kwargs:
arguments.append(kwargs)
elif kwargs:
if "caller" in kwargs:
raise TypeError(
f"macro {self.name!r} was invoked with two values for the special"
" caller argument. This is most likely a bug."
)
raise TypeError(
f"macro {self.name!r} takes no keyword argument {next(iter(kwargs))!r}"
)
if self.catch_varargs:
arguments.append(args[self._argument_count :])
elif len(args) > self._argument_count:
raise TypeError(
f"macro {self.name!r} takes not more than"
f" {len(self.arguments)} argument(s)"
)
return self._invoke(arguments, autoescape)
async def _async_invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str:
rv = await self._func(*arguments) # type: ignore
if autoescape:
return Markup(rv)
return rv # type: ignore
def _invoke(self, arguments: t.List[t.Any], autoescape: bool) -> str:
if self._environment.is_async:
return self._async_invoke(arguments, autoescape) # type: ignore
rv = self._func(*arguments)
if autoescape:
rv = Markup(rv)
return rv
def __repr__(self) -> str:
name = "anonymous" if self.name is None else repr(self.name)
return f"<{type(self).__name__} {name}>"
class Undefined:
"""The default undefined type. This undefined type can be printed and
iterated over, but every other access will raise an :exc:`UndefinedError`:
>>> foo = Undefined(name='foo')
>>> str(foo)
''
>>> not foo
True
>>> foo + 42
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
"""
__slots__ = (
"_undefined_hint",
"_undefined_obj",
"_undefined_name",
"_undefined_exception",
)
def __init__(
self,
hint: t.Optional[str] = None,
obj: t.Any = missing,
name: t.Optional[str] = None,
exc: t.Type[TemplateRuntimeError] = UndefinedError,
) -> None:
self._undefined_hint = hint
self._undefined_obj = obj
self._undefined_name = name
self._undefined_exception = exc
@property
def _undefined_message(self) -> str:
"""Build a message about the undefined value based on how it was
accessed.
"""
if self._undefined_hint:
return self._undefined_hint
if self._undefined_obj is missing:
return f"{self._undefined_name!r} is undefined"
if not isinstance(self._undefined_name, str):
return (
f"{object_type_repr(self._undefined_obj)} has no"
f" element {self._undefined_name!r}"
)
return (
f"{object_type_repr(self._undefined_obj)!r} has no"
f" attribute {self._undefined_name!r}"
)
@internalcode
def _fail_with_undefined_error(
self, *args: t.Any, **kwargs: t.Any
) -> "te.NoReturn":
"""Raise an :exc:`UndefinedError` when operations are performed
on the undefined value.
"""
raise self._undefined_exception(self._undefined_message)
@internalcode
def __getattr__(self, name: str) -> t.Any:
if name[:2] == "__":
raise AttributeError(name)
return self._fail_with_undefined_error()
__add__ = __radd__ = __sub__ = __rsub__ = _fail_with_undefined_error
__mul__ = __rmul__ = __div__ = __rdiv__ = _fail_with_undefined_error
__truediv__ = __rtruediv__ = _fail_with_undefined_error
__floordiv__ = __rfloordiv__ = _fail_with_undefined_error
__mod__ = __rmod__ = _fail_with_undefined_error
__pos__ = __neg__ = _fail_with_undefined_error
__call__ = __getitem__ = _fail_with_undefined_error
__lt__ = __le__ = __gt__ = __ge__ = _fail_with_undefined_error
__int__ = __float__ = __complex__ = _fail_with_undefined_error
__pow__ = __rpow__ = _fail_with_undefined_error
def __eq__(self, other: t.Any) -> bool:
return type(self) is type(other)
def __ne__(self, other: t.Any) -> bool:
return not self.__eq__(other)
def __hash__(self) -> int:
return id(type(self))
def __str__(self) -> str:
return ""
def __len__(self) -> int:
return 0
def __iter__(self) -> t.Iterator[t.Any]:
yield from ()
async def __aiter__(self) -> t.AsyncIterator[t.Any]:
for _ in ():
yield
def __bool__(self) -> bool:
return False
def __repr__(self) -> str:
return "Undefined"
def make_logging_undefined(
logger: t.Optional["logging.Logger"] = None, base: t.Type[Undefined] = Undefined
) -> t.Type[Undefined]:
"""Given a logger object this returns a new undefined class that will
log certain failures. It will log iterations and printing. If no
logger is given a default logger is created.
Example::
logger = logging.getLogger(__name__)
LoggingUndefined = make_logging_undefined(
logger=logger,
base=Undefined
)
.. versionadded:: 2.8
:param logger: the logger to use. If not provided, a default logger
is created.
:param base: the base class to add logging functionality to. This
defaults to :class:`Undefined`.
"""
if logger is None:
import logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(sys.stderr))
def _log_message(undef: Undefined) -> None:
logger.warning("Template variable warning: %s", undef._undefined_message)
class LoggingUndefined(base): # type: ignore
__slots__ = ()
def _fail_with_undefined_error( # type: ignore
self, *args: t.Any, **kwargs: t.Any
) -> "te.NoReturn":
try:
super()._fail_with_undefined_error(*args, **kwargs)
except self._undefined_exception as e:
logger.error("Template variable error: %s", e) # type: ignore
raise e
def __str__(self) -> str:
_log_message(self)
return super().__str__() # type: ignore
def __iter__(self) -> t.Iterator[t.Any]:
_log_message(self)
return super().__iter__() # type: ignore
def __bool__(self) -> bool:
_log_message(self)
return super().__bool__() # type: ignore
return LoggingUndefined
class ChainableUndefined(Undefined):
"""An undefined that is chainable, where both ``__getattr__`` and
``__getitem__`` return itself rather than raising an
:exc:`UndefinedError`.
>>> foo = ChainableUndefined(name='foo')
>>> str(foo.bar['baz'])
''
>>> foo.bar['baz'] + 42
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
.. versionadded:: 2.11.0
"""
__slots__ = ()
def __html__(self) -> str:
return str(self)
def __getattr__(self, _: str) -> "ChainableUndefined":
return self
__getitem__ = __getattr__ # type: ignore
class DebugUndefined(Undefined):
"""An undefined that returns the debug info when printed.
>>> foo = DebugUndefined(name='foo')
>>> str(foo)
'{{ foo }}'
>>> not foo
True
>>> foo + 42
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
"""
__slots__ = ()
def __str__(self) -> str:
if self._undefined_hint:
message = f"undefined value printed: {self._undefined_hint}"
elif self._undefined_obj is missing:
message = self._undefined_name # type: ignore
else:
message = (
f"no such element: {object_type_repr(self._undefined_obj)}"
f"[{self._undefined_name!r}]"
)
return f"{{{{ {message} }}}}"
class StrictUndefined(Undefined):
"""An undefined that barks on print and iteration as well as boolean
tests and all kinds of comparisons. In other words: you can do nothing
with it except checking if it's defined using the `defined` test.
>>> foo = StrictUndefined(name='foo')
>>> str(foo)
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
>>> not foo
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
>>> foo + 42
Traceback (most recent call last):
...
jinja2.exceptions.UndefinedError: 'foo' is undefined
"""
__slots__ = ()
__iter__ = __str__ = __len__ = Undefined._fail_with_undefined_error
__eq__ = __ne__ = __bool__ = __hash__ = Undefined._fail_with_undefined_error
__contains__ = Undefined._fail_with_undefined_error
# Remove slots attributes, after the metaclass is applied they are
# unneeded and contain wrong data for subclasses.
del (
Undefined.__slots__,
ChainableUndefined.__slots__,
DebugUndefined.__slots__,
StrictUndefined.__slots__,
)
"""A sandbox layer that ensures unsafe operations cannot be performed.
Useful when the template itself comes from an untrusted source.
"""
import operator
import types
import typing as t
from _string import formatter_field_name_split # type: ignore
from collections import abc
from collections import deque
from string import Formatter
from markupsafe import EscapeFormatter
from markupsafe import Markup
from .environment import Environment
from .exceptions import SecurityError
from .runtime import Context
from .runtime import Undefined
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
#: maximum number of items a range may produce
MAX_RANGE = 100000
#: Unsafe function attributes.
UNSAFE_FUNCTION_ATTRIBUTES: t.Set[str] = set()
#: Unsafe method attributes. Function attributes are unsafe for methods too.
UNSAFE_METHOD_ATTRIBUTES: t.Set[str] = set()
#: unsafe generator attributes.
UNSAFE_GENERATOR_ATTRIBUTES = {"gi_frame", "gi_code"}
#: unsafe attributes on coroutines
UNSAFE_COROUTINE_ATTRIBUTES = {"cr_frame", "cr_code"}
#: unsafe attributes on async generators
UNSAFE_ASYNC_GENERATOR_ATTRIBUTES = {"ag_code", "ag_frame"}
_mutable_spec: t.Tuple[t.Tuple[t.Type, t.FrozenSet[str]], ...] = (
(
abc.MutableSet,
frozenset(
[
"add",
"clear",
"difference_update",
"discard",
"pop",
"remove",
"symmetric_difference_update",
"update",
]
),
),
(
abc.MutableMapping,
frozenset(["clear", "pop", "popitem", "setdefault", "update"]),
),
(
abc.MutableSequence,
frozenset(["append", "reverse", "insert", "sort", "extend", "remove"]),
),
(
deque,
frozenset(
[
"append",
"appendleft",
"clear",
"extend",
"extendleft",
"pop",
"popleft",
"remove",
"rotate",
]
),
),
)
def inspect_format_method(callable: t.Callable) -> t.Optional[str]:
if not isinstance(
callable, (types.MethodType, types.BuiltinMethodType)
) or callable.__name__ not in ("format", "format_map"):
return None
obj = callable.__self__
if isinstance(obj, str):
return obj
return None
def safe_range(*args: int) -> range:
"""A range that can't generate ranges with a length of more than
MAX_RANGE items.
"""
rng = range(*args)
if len(rng) > MAX_RANGE:
raise OverflowError(
"Range too big. The sandbox blocks ranges larger than"
f" MAX_RANGE ({MAX_RANGE})."
)
return rng
def unsafe(f: F) -> F:
"""Marks a function or method as unsafe.
.. code-block: python
@unsafe
def delete(self):
pass
"""
f.unsafe_callable = True # type: ignore
return f
def is_internal_attribute(obj: t.Any, attr: str) -> bool:
"""Test if the attribute given is an internal python attribute. For
example this function returns `True` for the `func_code` attribute of
python objects. This is useful if the environment method
:meth:`~SandboxedEnvironment.is_safe_attribute` is overridden.
>>> from jinja2.sandbox import is_internal_attribute
>>> is_internal_attribute(str, "mro")
True
>>> is_internal_attribute(str, "upper")
False
"""
if isinstance(obj, types.FunctionType):
if attr in UNSAFE_FUNCTION_ATTRIBUTES:
return True
elif isinstance(obj, types.MethodType):
if attr in UNSAFE_FUNCTION_ATTRIBUTES or attr in UNSAFE_METHOD_ATTRIBUTES:
return True
elif isinstance(obj, type):
if attr == "mro":
return True
elif isinstance(obj, (types.CodeType, types.TracebackType, types.FrameType)):
return True
elif isinstance(obj, types.GeneratorType):
if attr in UNSAFE_GENERATOR_ATTRIBUTES:
return True
elif hasattr(types, "CoroutineType") and isinstance(obj, types.CoroutineType):
if attr in UNSAFE_COROUTINE_ATTRIBUTES:
return True
elif hasattr(types, "AsyncGeneratorType") and isinstance(
obj, types.AsyncGeneratorType
):
if attr in UNSAFE_ASYNC_GENERATOR_ATTRIBUTES:
return True
return attr.startswith("__")
def modifies_known_mutable(obj: t.Any, attr: str) -> bool:
"""This function checks if an attribute on a builtin mutable object
(list, dict, set or deque) or the corresponding ABCs would modify it
if called.
>>> modifies_known_mutable({}, "clear")
True
>>> modifies_known_mutable({}, "keys")
False
>>> modifies_known_mutable([], "append")
True
>>> modifies_known_mutable([], "index")
False
If called with an unsupported object, ``False`` is returned.
>>> modifies_known_mutable("foo", "upper")
False
"""
for typespec, unsafe in _mutable_spec:
if isinstance(obj, typespec):
return attr in unsafe
return False
class SandboxedEnvironment(Environment):
"""The sandboxed environment. It works like the regular environment but
tells the compiler to generate sandboxed code. Additionally subclasses of
this environment may override the methods that tell the runtime what
attributes or functions are safe to access.
If the template tries to access insecure code a :exc:`SecurityError` is
raised. However also other exceptions may occur during the rendering so
the caller has to ensure that all exceptions are caught.
"""
sandboxed = True
#: default callback table for the binary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`binop_table`
default_binop_table: t.Dict[str, t.Callable[[t.Any, t.Any], t.Any]] = {
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": operator.truediv,
"//": operator.floordiv,
"**": operator.pow,
"%": operator.mod,
}
#: default callback table for the unary operators. A copy of this is
#: available on each instance of a sandboxed environment as
#: :attr:`unop_table`
default_unop_table: t.Dict[str, t.Callable[[t.Any], t.Any]] = {
"+": operator.pos,
"-": operator.neg,
}
#: a set of binary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
#: :meth:`call_binop` method that will perform the operator. The default
#: operator callback is specified by :attr:`binop_table`.
#:
#: The following binary operators are interceptable:
#: ``//``, ``%``, ``+``, ``*``, ``-``, ``/``, and ``**``
#:
#: The default operation form the operator table corresponds to the
#: builtin function. Intercepted calls are always slower than the native
#: operator call, so make sure only to intercept the ones you are
#: interested in.
#:
#: .. versionadded:: 2.6
intercepted_binops: t.FrozenSet[str] = frozenset()
#: a set of unary operators that should be intercepted. Each operator
#: that is added to this set (empty by default) is delegated to the
#: :meth:`call_unop` method that will perform the operator. The default
#: operator callback is specified by :attr:`unop_table`.
#:
#: The following unary operators are interceptable: ``+``, ``-``
#:
#: The default operation form the operator table corresponds to the
#: builtin function. Intercepted calls are always slower than the native
#: operator call, so make sure only to intercept the ones you are
#: interested in.
#:
#: .. versionadded:: 2.6
intercepted_unops: t.FrozenSet[str] = frozenset()
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.globals["range"] = safe_range
self.binop_table = self.default_binop_table.copy()
self.unop_table = self.default_unop_table.copy()
def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
"""The sandboxed environment will call this method to check if the
attribute of an object is safe to access. Per default all attributes
starting with an underscore are considered private as well as the
special attributes of internal python objects as returned by the
:func:`is_internal_attribute` function.
"""
return not (attr.startswith("_") or is_internal_attribute(obj, attr))
def is_safe_callable(self, obj: t.Any) -> bool:
"""Check if an object is safely callable. By default callables
are considered safe unless decorated with :func:`unsafe`.
This also recognizes the Django convention of setting
``func.alters_data = True``.
"""
return not (
getattr(obj, "unsafe_callable", False) or getattr(obj, "alters_data", False)
)
def call_binop(
self, context: Context, operator: str, left: t.Any, right: t.Any
) -> t.Any:
"""For intercepted binary operator calls (:meth:`intercepted_binops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
.. versionadded:: 2.6
"""
return self.binop_table[operator](left, right)
def call_unop(self, context: Context, operator: str, arg: t.Any) -> t.Any:
"""For intercepted unary operator calls (:meth:`intercepted_unops`)
this function is executed instead of the builtin operator. This can
be used to fine tune the behavior of certain operators.
.. versionadded:: 2.6
"""
return self.unop_table[operator](arg)
def getitem(
self, obj: t.Any, argument: t.Union[str, t.Any]
) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code."""
try:
return obj[argument]
except (TypeError, LookupError):
if isinstance(argument, str):
try:
attr = str(argument)
except Exception:
pass
else:
try:
value = getattr(obj, attr)
except AttributeError:
pass
else:
if self.is_safe_attribute(obj, argument, value):
return value
return self.unsafe_undefined(obj, argument)
return self.undefined(obj=obj, name=argument)
def getattr(self, obj: t.Any, attribute: str) -> t.Union[t.Any, Undefined]:
"""Subscribe an object from sandboxed code and prefer the
attribute. The attribute passed *must* be a bytestring.
"""
try:
value = getattr(obj, attribute)
except AttributeError:
try:
return obj[attribute]
except (TypeError, LookupError):
pass
else:
if self.is_safe_attribute(obj, attribute, value):
return value
return self.unsafe_undefined(obj, attribute)
return self.undefined(obj=obj, name=attribute)
def unsafe_undefined(self, obj: t.Any, attribute: str) -> Undefined:
"""Return an undefined object for unsafe attributes."""
return self.undefined(
f"access to attribute {attribute!r} of"
f" {type(obj).__name__!r} object is unsafe.",
name=attribute,
obj=obj,
exc=SecurityError,
)
def format_string(
self,
s: str,
args: t.Tuple[t.Any, ...],
kwargs: t.Dict[str, t.Any],
format_func: t.Optional[t.Callable] = None,
) -> str:
"""If a format call is detected, then this is routed through this
method so that our safety sandbox can be used for it.
"""
formatter: SandboxedFormatter
if isinstance(s, Markup):
formatter = SandboxedEscapeFormatter(self, escape=s.escape)
else:
formatter = SandboxedFormatter(self)
if format_func is not None and format_func.__name__ == "format_map":
if len(args) != 1 or kwargs:
raise TypeError(
"format_map() takes exactly one argument"
f" {len(args) + (kwargs is not None)} given"
)
kwargs = args[0]
args = ()
rv = formatter.vformat(s, args, kwargs)
return type(s)(rv)
def call(
__self, # noqa: B902
__context: Context,
__obj: t.Any,
*args: t.Any,
**kwargs: t.Any,
) -> t.Any:
"""Call an object from sandboxed code."""
fmt = inspect_format_method(__obj)
if fmt is not None:
return __self.format_string(fmt, args, kwargs, __obj)
# the double prefixes are to avoid double keyword argument
# errors when proxying the call.
if not __self.is_safe_callable(__obj):
raise SecurityError(f"{__obj!r} is not safely callable")
return __context.call(__obj, *args, **kwargs)
class ImmutableSandboxedEnvironment(SandboxedEnvironment):
"""Works exactly like the regular `SandboxedEnvironment` but does not
permit modifications on the builtin mutable objects `list`, `set`, and
`dict` by using the :func:`modifies_known_mutable` function.
"""
def is_safe_attribute(self, obj: t.Any, attr: str, value: t.Any) -> bool:
if not super().is_safe_attribute(obj, attr, value):
return False
return not modifies_known_mutable(obj, attr)
class SandboxedFormatter(Formatter):
def __init__(self, env: Environment, **kwargs: t.Any) -> None:
self._env = env
super().__init__(**kwargs)
def get_field(
self, field_name: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> t.Tuple[t.Any, str]:
first, rest = formatter_field_name_split(field_name)
obj = self.get_value(first, args, kwargs)
for is_attr, i in rest:
if is_attr:
obj = self._env.getattr(obj, i)
else:
obj = self._env.getitem(obj, i)
return obj, first
class SandboxedEscapeFormatter(SandboxedFormatter, EscapeFormatter):
pass
"""Built-in template tests used with the ``is`` operator."""
import operator
import typing as t
from collections import abc
from numbers import Number
from .runtime import Undefined
from .utils import pass_environment
if t.TYPE_CHECKING:
from .environment import Environment
def test_odd(value: int) -> bool:
"""Return true if the variable is odd."""
return value % 2 == 1
def test_even(value: int) -> bool:
"""Return true if the variable is even."""
return value % 2 == 0
def test_divisibleby(value: int, num: int) -> bool:
"""Check if a variable is divisible by a number."""
return value % num == 0
def test_defined(value: t.Any) -> bool:
"""Return true if the variable is defined:
.. sourcecode:: jinja
{% if variable is defined %}
value of variable: {{ variable }}
{% else %}
variable is not defined
{% endif %}
See the :func:`default` filter for a simple way to set undefined
variables.
"""
return not isinstance(value, Undefined)
def test_undefined(value: t.Any) -> bool:
"""Like :func:`defined` but the other way round."""
return isinstance(value, Undefined)
@pass_environment
def test_filter(env: "Environment", value: str) -> bool:
"""Check if a filter exists by name. Useful if a filter may be
optionally available.
.. code-block:: jinja
{% if 'markdown' is filter %}
{{ value | markdown }}
{% else %}
{{ value }}
{% endif %}
.. versionadded:: 3.0
"""
return value in env.filters
@pass_environment
def test_test(env: "Environment", value: str) -> bool:
"""Check if a test exists by name. Useful if a test may be
optionally available.
.. code-block:: jinja
{% if 'loud' is test %}
{% if value is loud %}
{{ value|upper }}
{% else %}
{{ value|lower }}
{% endif %}
{% else %}
{{ value }}
{% endif %}
.. versionadded:: 3.0
"""
return value in env.tests
def test_none(value: t.Any) -> bool:
"""Return true if the variable is none."""
return value is None
def test_boolean(value: t.Any) -> bool:
"""Return true if the object is a boolean value.
.. versionadded:: 2.11
"""
return value is True or value is False
def test_false(value: t.Any) -> bool:
"""Return true if the object is False.
.. versionadded:: 2.11
"""
return value is False
def test_true(value: t.Any) -> bool:
"""Return true if the object is True.
.. versionadded:: 2.11
"""
return value is True
# NOTE: The existing 'number' test matches booleans and floats
def test_integer(value: t.Any) -> bool:
"""Return true if the object is an integer.
.. versionadded:: 2.11
"""
return isinstance(value, int) and value is not True and value is not False
# NOTE: The existing 'number' test matches booleans and integers
def test_float(value: t.Any) -> bool:
"""Return true if the object is a float.
.. versionadded:: 2.11
"""
return isinstance(value, float)
def test_lower(value: str) -> bool:
"""Return true if the variable is lowercased."""
return str(value).islower()
def test_upper(value: str) -> bool:
"""Return true if the variable is uppercased."""
return str(value).isupper()
def test_string(value: t.Any) -> bool:
"""Return true if the object is a string."""
return isinstance(value, str)
def test_mapping(value: t.Any) -> bool:
"""Return true if the object is a mapping (dict etc.).
.. versionadded:: 2.6
"""
return isinstance(value, abc.Mapping)
def test_number(value: t.Any) -> bool:
"""Return true if the variable is a number."""
return isinstance(value, Number)
def test_sequence(value: t.Any) -> bool:
"""Return true if the variable is a sequence. Sequences are variables
that are iterable.
"""
try:
len(value)
value.__getitem__
except Exception:
return False
return True
def test_sameas(value: t.Any, other: t.Any) -> bool:
"""Check if an object points to the same memory address than another
object:
.. sourcecode:: jinja
{% if foo.attribute is sameas false %}
the foo attribute really is the `False` singleton
{% endif %}
"""
return value is other
def test_iterable(value: t.Any) -> bool:
"""Check if it's possible to iterate over an object."""
try:
iter(value)
except TypeError:
return False
return True
def test_escaped(value: t.Any) -> bool:
"""Check if the value is escaped."""
return hasattr(value, "__html__")
def test_in(value: t.Any, seq: t.Container) -> bool:
"""Check if value is in seq.
.. versionadded:: 2.10
"""
return value in seq
TESTS = {
"odd": test_odd,
"even": test_even,
"divisibleby": test_divisibleby,
"defined": test_defined,
"undefined": test_undefined,
"filter": test_filter,
"test": test_test,
"none": test_none,
"boolean": test_boolean,
"false": test_false,
"true": test_true,
"integer": test_integer,
"float": test_float,
"lower": test_lower,
"upper": test_upper,
"string": test_string,
"mapping": test_mapping,
"number": test_number,
"sequence": test_sequence,
"iterable": test_iterable,
"callable": callable,
"sameas": test_sameas,
"escaped": test_escaped,
"in": test_in,
"==": operator.eq,
"eq": operator.eq,
"equalto": operator.eq,
"!=": operator.ne,
"ne": operator.ne,
">": operator.gt,
"gt": operator.gt,
"greaterthan": operator.gt,
"ge": operator.ge,
">=": operator.ge,
"<": operator.lt,
"lt": operator.lt,
"lessthan": operator.lt,
"<=": operator.le,
"le": operator.le,
}
import enum
import json
import os
import re
import typing as t
from collections import abc
from collections import deque
from random import choice
from random import randrange
from threading import Lock
from types import CodeType
from urllib.parse import quote_from_bytes
import markupsafe
if t.TYPE_CHECKING:
import typing_extensions as te
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
# special singleton representing missing values for the runtime
missing: t.Any = type("MissingType", (), {"__repr__": lambda x: "missing"})()
internal_code: t.MutableSet[CodeType] = set()
concat = "".join
def pass_context(f: F) -> F:
"""Pass the :class:`~jinja2.runtime.Context` as the first argument
to the decorated function when called while rendering a template.
Can be used on functions, filters, and tests.
If only ``Context.eval_context`` is needed, use
:func:`pass_eval_context`. If only ``Context.environment`` is
needed, use :func:`pass_environment`.
.. versionadded:: 3.0.0
Replaces ``contextfunction`` and ``contextfilter``.
"""
f.jinja_pass_arg = _PassArg.context # type: ignore
return f
def pass_eval_context(f: F) -> F:
"""Pass the :class:`~jinja2.nodes.EvalContext` as the first argument
to the decorated function when called while rendering a template.
See :ref:`eval-context`.
Can be used on functions, filters, and tests.
If only ``EvalContext.environment`` is needed, use
:func:`pass_environment`.
.. versionadded:: 3.0.0
Replaces ``evalcontextfunction`` and ``evalcontextfilter``.
"""
f.jinja_pass_arg = _PassArg.eval_context # type: ignore
return f
def pass_environment(f: F) -> F:
"""Pass the :class:`~jinja2.Environment` as the first argument to
the decorated function when called while rendering a template.
Can be used on functions, filters, and tests.
.. versionadded:: 3.0.0
Replaces ``environmentfunction`` and ``environmentfilter``.
"""
f.jinja_pass_arg = _PassArg.environment # type: ignore
return f
class _PassArg(enum.Enum):
context = enum.auto()
eval_context = enum.auto()
environment = enum.auto()
@classmethod
def from_obj(cls, obj: F) -> t.Optional["_PassArg"]:
if hasattr(obj, "jinja_pass_arg"):
return obj.jinja_pass_arg # type: ignore
return None
def internalcode(f: F) -> F:
"""Marks the function as internally used"""
internal_code.add(f.__code__)
return f
def is_undefined(obj: t.Any) -> bool:
"""Check if the object passed is undefined. This does nothing more than
performing an instance check against :class:`Undefined` but looks nicer.
This can be used for custom filters or tests that want to react to
undefined variables. For example a custom default filter can look like
this::
def default(var, default=''):
if is_undefined(var):
return default
return var
"""
from .runtime import Undefined
return isinstance(obj, Undefined)
def consume(iterable: t.Iterable[t.Any]) -> None:
"""Consumes an iterable without doing anything with it."""
for _ in iterable:
pass
def clear_caches() -> None:
"""Jinja keeps internal caches for environments and lexers. These are
used so that Jinja doesn't have to recreate environments and lexers all
the time. Normally you don't have to care about that but if you are
measuring memory consumption you may want to clean the caches.
"""
from .environment import get_spontaneous_environment
from .lexer import _lexer_cache
get_spontaneous_environment.cache_clear()
_lexer_cache.clear()
def import_string(import_name: str, silent: bool = False) -> t.Any:
"""Imports an object based on a string. This is useful if you want to
use import paths as endpoints or something similar. An import path can
be specified either in dotted notation (``xml.sax.saxutils.escape``)
or with a colon as object delimiter (``xml.sax.saxutils:escape``).
If the `silent` is True the return value will be `None` if the import
fails.
:return: imported object
"""
try:
if ":" in import_name:
module, obj = import_name.split(":", 1)
elif "." in import_name:
module, _, obj = import_name.rpartition(".")
else:
return __import__(import_name)
return getattr(__import__(module, None, None, [obj]), obj)
except (ImportError, AttributeError):
if not silent:
raise
def open_if_exists(filename: str, mode: str = "rb") -> t.Optional[t.IO]:
"""Returns a file descriptor for the filename if that file exists,
otherwise ``None``.
"""
if not os.path.isfile(filename):
return None
return open(filename, mode)
def object_type_repr(obj: t.Any) -> str:
"""Returns the name of the object's type. For some recognized
singletons the name of the object is returned instead. (For
example for `None` and `Ellipsis`).
"""
if obj is None:
return "None"
elif obj is Ellipsis:
return "Ellipsis"
cls = type(obj)
if cls.__module__ == "builtins":
return f"{cls.__name__} object"
return f"{cls.__module__}.{cls.__name__} object"
def pformat(obj: t.Any) -> str:
"""Format an object using :func:`pprint.pformat`."""
from pprint import pformat
return pformat(obj)
_http_re = re.compile(
r"""
^
(
(https?://|www\.) # scheme or www
(([\w%-]+\.)+)? # subdomain
(
[a-z]{2,63} # basic tld
|
xn--[\w%]{2,59} # idna tld
)
|
([\w%-]{2,63}\.)+ # basic domain
(com|net|int|edu|gov|org|info|mil) # basic tld
|
(https?://) # scheme
(
(([\d]{1,3})(\.[\d]{1,3}){3}) # IPv4
|
(\[([\da-f]{0,4}:){2}([\da-f]{0,4}:?){1,6}]) # IPv6
)
)
(?::[\d]{1,5})? # port
(?:[/?#]\S*)? # path, query, and fragment
$
""",
re.IGNORECASE | re.VERBOSE,
)
_email_re = re.compile(r"^\S+@\w[\w.-]*\.\w+$")
def urlize(
text: str,
trim_url_limit: t.Optional[int] = None,
rel: t.Optional[str] = None,
target: t.Optional[str] = None,
extra_schemes: t.Optional[t.Iterable[str]] = None,
) -> str:
"""Convert URLs in text into clickable links.
This may not recognize links in some situations. Usually, a more
comprehensive formatter, such as a Markdown library, is a better
choice.
Works on ``http://``, ``https://``, ``www.``, ``mailto:``, and email
addresses. Links with trailing punctuation (periods, commas, closing
parentheses) and leading punctuation (opening parentheses) are
recognized excluding the punctuation. Email addresses that include
header fields are not recognized (for example,
``mailto:address@example.com?cc=copy@example.com``).
:param text: Original text containing URLs to link.
:param trim_url_limit: Shorten displayed URL values to this length.
:param target: Add the ``target`` attribute to links.
:param rel: Add the ``rel`` attribute to links.
:param extra_schemes: Recognize URLs that start with these schemes
in addition to the default behavior.
.. versionchanged:: 3.0
The ``extra_schemes`` parameter was added.
.. versionchanged:: 3.0
Generate ``https://`` links for URLs without a scheme.
.. versionchanged:: 3.0
The parsing rules were updated. Recognize email addresses with
or without the ``mailto:`` scheme. Validate IP addresses. Ignore
parentheses and brackets in more cases.
"""
if trim_url_limit is not None:
def trim_url(x: str) -> str:
if len(x) > trim_url_limit:
return f"{x[:trim_url_limit]}..."
return x
else:
def trim_url(x: str) -> str:
return x
words = re.split(r"(\s+)", str(markupsafe.escape(text)))
rel_attr = f' rel="{markupsafe.escape(rel)}"' if rel else ""
target_attr = f' target="{markupsafe.escape(target)}"' if target else ""
for i, word in enumerate(words):
head, middle, tail = "", word, ""
match = re.match(r"^([(<]|&lt;)+", middle)
if match:
head = match.group()
middle = middle[match.end() :]
# Unlike lead, which is anchored to the start of the string,
# need to check that the string ends with any of the characters
# before trying to match all of them, to avoid backtracking.
if middle.endswith((")", ">", ".", ",", "\n", "&gt;")):
match = re.search(r"([)>.,\n]|&gt;)+$", middle)
if match:
tail = match.group()
middle = middle[: match.start()]
# Prefer balancing parentheses in URLs instead of ignoring a
# trailing character.
for start_char, end_char in ("(", ")"), ("<", ">"), ("&lt;", "&gt;"):
start_count = middle.count(start_char)
if start_count <= middle.count(end_char):
# Balanced, or lighter on the left
continue
# Move as many as possible from the tail to balance
for _ in range(min(start_count, tail.count(end_char))):
end_index = tail.index(end_char) + len(end_char)
# Move anything in the tail before the end char too
middle += tail[:end_index]
tail = tail[end_index:]
if _http_re.match(middle):
if middle.startswith("https://") or middle.startswith("http://"):
middle = (
f'<a href="{middle}"{rel_attr}{target_attr}>{trim_url(middle)}</a>'
)
else:
middle = (
f'<a href="https://{middle}"{rel_attr}{target_attr}>'
f"{trim_url(middle)}</a>"
)
elif middle.startswith("mailto:") and _email_re.match(middle[7:]):
middle = f'<a href="{middle}">{middle[7:]}</a>'
elif (
"@" in middle
and not middle.startswith("www.")
and ":" not in middle
and _email_re.match(middle)
):
middle = f'<a href="mailto:{middle}">{middle}</a>'
elif extra_schemes is not None:
for scheme in extra_schemes:
if middle != scheme and middle.startswith(scheme):
middle = f'<a href="{middle}"{rel_attr}{target_attr}>{middle}</a>'
words[i] = f"{head}{middle}{tail}"
return "".join(words)
def generate_lorem_ipsum(
n: int = 5, html: bool = True, min: int = 20, max: int = 100
) -> str:
"""Generate some lorem ipsum for the template."""
from .constants import LOREM_IPSUM_WORDS
words = LOREM_IPSUM_WORDS.split()
result = []
for _ in range(n):
next_capitalized = True
last_comma = last_fullstop = 0
word = None
last = None
p = []
# each paragraph contains out of 20 to 100 words.
for idx, _ in enumerate(range(randrange(min, max))):
while True:
word = choice(words)
if word != last:
last = word
break
if next_capitalized:
word = word.capitalize()
next_capitalized = False
# add commas
if idx - randrange(3, 8) > last_comma:
last_comma = idx
last_fullstop += 2
word += ","
# add end of sentences
if idx - randrange(10, 20) > last_fullstop:
last_comma = last_fullstop = idx
word += "."
next_capitalized = True
p.append(word)
# ensure that the paragraph ends with a dot.
p_str = " ".join(p)
if p_str.endswith(","):
p_str = p_str[:-1] + "."
elif not p_str.endswith("."):
p_str += "."
result.append(p_str)
if not html:
return "\n\n".join(result)
return markupsafe.Markup(
"\n".join(f"<p>{markupsafe.escape(x)}</p>" for x in result)
)
def url_quote(obj: t.Any, charset: str = "utf-8", for_qs: bool = False) -> str:
"""Quote a string for use in a URL using the given charset.
:param obj: String or bytes to quote. Other types are converted to
string then encoded to bytes using the given charset.
:param charset: Encode text to bytes using this charset.
:param for_qs: Quote "/" and use "+" for spaces.
"""
if not isinstance(obj, bytes):
if not isinstance(obj, str):
obj = str(obj)
obj = obj.encode(charset)
safe = b"" if for_qs else b"/"
rv = quote_from_bytes(obj, safe)
if for_qs:
rv = rv.replace("%20", "+")
return rv
@abc.MutableMapping.register
class LRUCache:
"""A simple LRU Cache implementation."""
# this is fast for small capacities (something below 1000) but doesn't
# scale. But as long as it's only used as storage for templates this
# won't do any harm.
def __init__(self, capacity: int) -> None:
self.capacity = capacity
self._mapping: t.Dict[t.Any, t.Any] = {}
self._queue: "te.Deque[t.Any]" = deque()
self._postinit()
def _postinit(self) -> None:
# alias all queue methods for faster lookup
self._popleft = self._queue.popleft
self._pop = self._queue.pop
self._remove = self._queue.remove
self._wlock = Lock()
self._append = self._queue.append
def __getstate__(self) -> t.Mapping[str, t.Any]:
return {
"capacity": self.capacity,
"_mapping": self._mapping,
"_queue": self._queue,
}
def __setstate__(self, d: t.Mapping[str, t.Any]) -> None:
self.__dict__.update(d)
self._postinit()
def __getnewargs__(self) -> t.Tuple:
return (self.capacity,)
def copy(self) -> "LRUCache":
"""Return a shallow copy of the instance."""
rv = self.__class__(self.capacity)
rv._mapping.update(self._mapping)
rv._queue.extend(self._queue)
return rv
def get(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Return an item from the cache dict or `default`"""
try:
return self[key]
except KeyError:
return default
def setdefault(self, key: t.Any, default: t.Any = None) -> t.Any:
"""Set `default` if the key is not in the cache otherwise
leave unchanged. Return the value of this key.
"""
try:
return self[key]
except KeyError:
self[key] = default
return default
def clear(self) -> None:
"""Clear the cache."""
with self._wlock:
self._mapping.clear()
self._queue.clear()
def __contains__(self, key: t.Any) -> bool:
"""Check if a key exists in this cache."""
return key in self._mapping
def __len__(self) -> int:
"""Return the current size of the cache."""
return len(self._mapping)
def __repr__(self) -> str:
return f"<{type(self).__name__} {self._mapping!r}>"
def __getitem__(self, key: t.Any) -> t.Any:
"""Get an item from the cache. Moves the item up so that it has the
highest priority then.
Raise a `KeyError` if it does not exist.
"""
with self._wlock:
rv = self._mapping[key]
if self._queue[-1] != key:
try:
self._remove(key)
except ValueError:
# if something removed the key from the container
# when we read, ignore the ValueError that we would
# get otherwise.
pass
self._append(key)
return rv
def __setitem__(self, key: t.Any, value: t.Any) -> None:
"""Sets the value for an item. Moves the item up so that it
has the highest priority then.
"""
with self._wlock:
if key in self._mapping:
self._remove(key)
elif len(self._mapping) == self.capacity:
del self._mapping[self._popleft()]
self._append(key)
self._mapping[key] = value
def __delitem__(self, key: t.Any) -> None:
"""Remove an item from the cache dict.
Raise a `KeyError` if it does not exist.
"""
with self._wlock:
del self._mapping[key]
try:
self._remove(key)
except ValueError:
pass
def items(self) -> t.Iterable[t.Tuple[t.Any, t.Any]]:
"""Return a list of items."""
result = [(key, self._mapping[key]) for key in list(self._queue)]
result.reverse()
return result
def values(self) -> t.Iterable[t.Any]:
"""Return a list of all values."""
return [x[1] for x in self.items()]
def keys(self) -> t.Iterable[t.Any]:
"""Return a list of all keys ordered by most recent usage."""
return list(self)
def __iter__(self) -> t.Iterator[t.Any]:
return reversed(tuple(self._queue))
def __reversed__(self) -> t.Iterator[t.Any]:
"""Iterate over the keys in the cache dict, oldest items
coming first.
"""
return iter(tuple(self._queue))
__copy__ = copy
def select_autoescape(
enabled_extensions: t.Collection[str] = ("html", "htm", "xml"),
disabled_extensions: t.Collection[str] = (),
default_for_string: bool = True,
default: bool = False,
) -> t.Callable[[t.Optional[str]], bool]:
"""Intelligently sets the initial value of autoescaping based on the
filename of the template. This is the recommended way to configure
autoescaping if you do not want to write a custom function yourself.
If you want to enable it for all templates created from strings or
for all templates with `.html` and `.xml` extensions::
from jinja2 import Environment, select_autoescape
env = Environment(autoescape=select_autoescape(
enabled_extensions=('html', 'xml'),
default_for_string=True,
))
Example configuration to turn it on at all times except if the template
ends with `.txt`::
from jinja2 import Environment, select_autoescape
env = Environment(autoescape=select_autoescape(
disabled_extensions=('txt',),
default_for_string=True,
default=True,
))
The `enabled_extensions` is an iterable of all the extensions that
autoescaping should be enabled for. Likewise `disabled_extensions` is
a list of all templates it should be disabled for. If a template is
loaded from a string then the default from `default_for_string` is used.
If nothing matches then the initial value of autoescaping is set to the
value of `default`.
For security reasons this function operates case insensitive.
.. versionadded:: 2.9
"""
enabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in enabled_extensions)
disabled_patterns = tuple(f".{x.lstrip('.').lower()}" for x in disabled_extensions)
def autoescape(template_name: t.Optional[str]) -> bool:
if template_name is None:
return default_for_string
template_name = template_name.lower()
if template_name.endswith(enabled_patterns):
return True
if template_name.endswith(disabled_patterns):
return False
return default
return autoescape
def htmlsafe_json_dumps(
obj: t.Any, dumps: t.Optional[t.Callable[..., str]] = None, **kwargs: t.Any
) -> markupsafe.Markup:
"""Serialize an object to a string of JSON with :func:`json.dumps`,
then replace HTML-unsafe characters with Unicode escapes and mark
the result safe with :class:`~markupsafe.Markup`.
This is available in templates as the ``|tojson`` filter.
The following characters are escaped: ``<``, ``>``, ``&``, ``'``.
The returned string is safe to render in HTML documents and
``<script>`` tags. The exception is in HTML attributes that are
double quoted; either use single quotes or the ``|forceescape``
filter.
:param obj: The object to serialize to JSON.
:param dumps: The ``dumps`` function to use. Defaults to
``env.policies["json.dumps_function"]``, which defaults to
:func:`json.dumps`.
:param kwargs: Extra arguments to pass to ``dumps``. Merged onto
``env.policies["json.dumps_kwargs"]``.
.. versionchanged:: 3.0
The ``dumper`` parameter is renamed to ``dumps``.
.. versionadded:: 2.9
"""
if dumps is None:
dumps = json.dumps
return markupsafe.Markup(
dumps(obj, **kwargs)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
.replace("'", "\\u0027")
)
class Cycler:
"""Cycle through values by yield them one at a time, then restarting
once the end is reached. Available as ``cycler`` in templates.
Similar to ``loop.cycle``, but can be used outside loops or across
multiple loops. For example, render a list of folders and files in a
list, alternating giving them "odd" and "even" classes.
.. code-block:: html+jinja
{% set row_class = cycler("odd", "even") %}
<ul class="browser">
{% for folder in folders %}
<li class="folder {{ row_class.next() }}">{{ folder }}
{% endfor %}
{% for file in files %}
<li class="file {{ row_class.next() }}">{{ file }}
{% endfor %}
</ul>
:param items: Each positional argument will be yielded in the order
given for each cycle.
.. versionadded:: 2.1
"""
def __init__(self, *items: t.Any) -> None:
if not items:
raise RuntimeError("at least one item has to be provided")
self.items = items
self.pos = 0
def reset(self) -> None:
"""Resets the current item to the first item."""
self.pos = 0
@property
def current(self) -> t.Any:
"""Return the current item. Equivalent to the item that will be
returned next time :meth:`next` is called.
"""
return self.items[self.pos]
def next(self) -> t.Any:
"""Return the current item, then advance :attr:`current` to the
next item.
"""
rv = self.current
self.pos = (self.pos + 1) % len(self.items)
return rv
__next__ = next
class Joiner:
"""A joining helper for templates."""
def __init__(self, sep: str = ", ") -> None:
self.sep = sep
self.used = False
def __call__(self) -> str:
if not self.used:
self.used = True
return ""
return self.sep
class Namespace:
"""A namespace object that can hold arbitrary attributes. It may be
initialized from a dictionary or with keyword arguments."""
def __init__(*args: t.Any, **kwargs: t.Any) -> None: # noqa: B902
self, args = args[0], args[1:]
self.__attrs = dict(*args, **kwargs)
def __getattribute__(self, name: str) -> t.Any:
# __class__ is needed for the awaitable check in async mode
if name in {"_Namespace__attrs", "__class__"}:
return object.__getattribute__(self, name)
try:
return self.__attrs[name]
except KeyError:
raise AttributeError(name) from None
def __setitem__(self, name: str, value: t.Any) -> None:
self.__attrs[name] = value
def __repr__(self) -> str:
return f"<Namespace {self.__attrs!r}>"
"""API for traversing the AST nodes. Implemented by the compiler and
meta introspection.
"""
import typing as t
from .nodes import Node
if t.TYPE_CHECKING:
import typing_extensions as te
class VisitCallable(te.Protocol):
def __call__(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
...
class NodeVisitor:
"""Walks the abstract syntax tree and call visitor functions for every
node found. The visitor functions may return values which will be
forwarded by the `visit` method.
Per default the visitor functions for the nodes are ``'visit_'`` +
class name of the node. So a `TryFinally` node visit function would
be `visit_TryFinally`. This behavior can be changed by overriding
the `get_visitor` function. If no visitor function exists for a node
(return value `None`) the `generic_visit` visitor is used instead.
"""
def get_visitor(self, node: Node) -> "t.Optional[VisitCallable]":
"""Return the visitor function for this node or `None` if no visitor
exists for this node. In that case the generic visit function is
used instead.
"""
return getattr(self, f"visit_{type(node).__name__}", None)
def visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Visit a node."""
f = self.get_visitor(node)
if f is not None:
return f(node, *args, **kwargs)
return self.generic_visit(node, *args, **kwargs)
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.Any:
"""Called if no explicit visitor function exists for a node."""
for child_node in node.iter_child_nodes():
self.visit(child_node, *args, **kwargs)
class NodeTransformer(NodeVisitor):
"""Walks the abstract syntax tree and allows modifications of nodes.
The `NodeTransformer` will walk the AST and use the return value of the
visitor functions to replace or remove the old node. If the return
value of the visitor function is `None` the node will be removed
from the previous location otherwise it's replaced with the return
value. The return value may be the original node in which case no
replacement takes place.
"""
def generic_visit(self, node: Node, *args: t.Any, **kwargs: t.Any) -> Node:
for field, old_value in node.iter_fields():
if isinstance(old_value, list):
new_values = []
for value in old_value:
if isinstance(value, Node):
value = self.visit(value, *args, **kwargs)
if value is None:
continue
elif not isinstance(value, Node):
new_values.extend(value)
continue
new_values.append(value)
old_value[:] = new_values
elif isinstance(old_value, Node):
new_node = self.visit(old_value, *args, **kwargs)
if new_node is None:
delattr(node, field)
else:
setattr(node, field, new_node)
return node
def visit_list(self, node: Node, *args: t.Any, **kwargs: t.Any) -> t.List[Node]:
"""As transformers may return lists in some places this method
can be used to enforce a list as return value.
"""
rv = self.visit(node, *args, **kwargs)
if not isinstance(rv, list):
return [rv]
return rv
from typing import List, Optional
__version__ = "23.0.1"
def main(args: Optional[List[str]] = None) -> int:
"""This is an internal API only meant for use by pip's own console scripts.
For additional details, see https://github.com/pypa/pip/issues/7498.
"""
from pip._internal.utils.entrypoints import _wrapper
return _wrapper(args)
import os
import sys
import warnings
# Remove '' and current working directory from the first entry
# of sys.path, if present to avoid using current directory
# in pip commands check, freeze, install, list and show,
# when invoked as python -m pip <command>
if sys.path[0] in ("", os.getcwd()):
sys.path.pop(0)
# If we are running from a wheel, add the wheel to sys.path
# This allows the usage python pip-*.whl/pip install pip-*.whl
if __package__ == "":
# __file__ is pip-*.whl/pip/__main__.py
# first dirname call strips of '/__main__.py', second strips off '/pip'
# Resulting path is the name of the wheel itself
# Add that to sys.path so we can import pip
path = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, path)
if __name__ == "__main__":
# Work around the error reported in #9540, pending a proper fix.
# Note: It is essential the warning filter is set *before* importing
# pip, as the deprecation happens at import time, not runtime.
warnings.filterwarnings(
"ignore", category=DeprecationWarning, module=".*packaging\\.version"
)
from pip._internal.cli.main import main as _main
sys.exit(_main())
"""Execute exactly this copy of pip, within a different environment.
This file is named as it is, to ensure that this module can't be imported via
an import statement.
"""
# /!\ This version compatibility check section must be Python 2 compatible. /!\
import sys
# Copied from setup.py
PYTHON_REQUIRES = (3, 7)
def version_str(version): # type: ignore
return ".".join(str(v) for v in version)
if sys.version_info[:2] < PYTHON_REQUIRES:
raise SystemExit(
"This version of pip does not support python {} (requires >={}).".format(
version_str(sys.version_info[:2]), version_str(PYTHON_REQUIRES)
)
)
# From here on, we can use Python 3 features, but the syntax must remain
# Python 2 compatible.
import runpy # noqa: E402
from importlib.machinery import PathFinder # noqa: E402
from os.path import dirname # noqa: E402
PIP_SOURCES_ROOT = dirname(dirname(__file__))
class PipImportRedirectingFinder:
@classmethod
def find_spec(self, fullname, path=None, target=None): # type: ignore
if fullname != "pip":
return None
spec = PathFinder.find_spec(fullname, [PIP_SOURCES_ROOT], target)
assert spec, (PIP_SOURCES_ROOT, fullname)
return spec
sys.meta_path.insert(0, PipImportRedirectingFinder())
assert __name__ == "__main__", "Cannot run __pip-runner__.py as a non-main module"
runpy.run_module("pip", run_name="__main__", alter_sys=True)
from typing import List, Optional
import pip._internal.utils.inject_securetransport # noqa
from pip._internal.utils import _log
# init_logging() must be called before any call to logging.getLogger()
# which happens at import of most modules.
_log.init_logging()
def main(args: (Optional[List[str]]) = None) -> int:
"""This is preserved for old console scripts that may still be referencing
it.
For additional details, see https://github.com/pypa/pip/issues/7498.
"""
from pip._internal.utils.entrypoints import _wrapper
return _wrapper(args)
"""Build Environment used for isolation during sdist building
"""
import logging
import os
import pathlib
import site
import sys
import textwrap
from collections import OrderedDict
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple, Type, Union
from pip._vendor.certifi import where
from pip._vendor.packaging.requirements import Requirement
from pip._vendor.packaging.version import Version
from pip import __file__ as pip_location
from pip._internal.cli.spinners import open_spinner
from pip._internal.locations import get_platlib, get_purelib, get_scheme
from pip._internal.metadata import get_default_environment, get_environment
from pip._internal.utils.subprocess import call_subprocess
from pip._internal.utils.temp_dir import TempDirectory, tempdir_kinds
if TYPE_CHECKING:
from pip._internal.index.package_finder import PackageFinder
logger = logging.getLogger(__name__)
def _dedup(a: str, b: str) -> Union[Tuple[str], Tuple[str, str]]:
return (a, b) if a != b else (a,)
class _Prefix:
def __init__(self, path: str) -> None:
self.path = path
self.setup = False
scheme = get_scheme("", prefix=path)
self.bin_dir = scheme.scripts
self.lib_dirs = _dedup(scheme.purelib, scheme.platlib)
def get_runnable_pip() -> str:
"""Get a file to pass to a Python executable, to run the currently-running pip.
This is used to run a pip subprocess, for installing requirements into the build
environment.
"""
source = pathlib.Path(pip_location).resolve().parent
if not source.is_dir():
# This would happen if someone is using pip from inside a zip file. In that
# case, we can use that directly.
return str(source)
return os.fsdecode(source / "__pip-runner__.py")
def _get_system_sitepackages() -> Set[str]:
"""Get system site packages
Usually from site.getsitepackages,
but fallback on `get_purelib()/get_platlib()` if unavailable
(e.g. in a virtualenv created by virtualenv<20)
Returns normalized set of strings.
"""
if hasattr(site, "getsitepackages"):
system_sites = site.getsitepackages()
else:
# virtualenv < 20 overwrites site.py without getsitepackages
# fallback on get_purelib/get_platlib.
# this is known to miss things, but shouldn't in the cases
# where getsitepackages() has been removed (inside a virtualenv)
system_sites = [get_purelib(), get_platlib()]
return {os.path.normcase(path) for path in system_sites}
class BuildEnvironment:
"""Creates and manages an isolated environment to install build deps"""
def __init__(self) -> None:
temp_dir = TempDirectory(kind=tempdir_kinds.BUILD_ENV, globally_managed=True)
self._prefixes = OrderedDict(
(name, _Prefix(os.path.join(temp_dir.path, name)))
for name in ("normal", "overlay")
)
self._bin_dirs: List[str] = []
self._lib_dirs: List[str] = []
for prefix in reversed(list(self._prefixes.values())):
self._bin_dirs.append(prefix.bin_dir)
self._lib_dirs.extend(prefix.lib_dirs)
# Customize site to:
# - ensure .pth files are honored
# - prevent access to system site packages
system_sites = _get_system_sitepackages()
self._site_dir = os.path.join(temp_dir.path, "site")
if not os.path.exists(self._site_dir):
os.mkdir(self._site_dir)
with open(
os.path.join(self._site_dir, "sitecustomize.py"), "w", encoding="utf-8"
) as fp:
fp.write(
textwrap.dedent(
"""
import os, site, sys
# First, drop system-sites related paths.
original_sys_path = sys.path[:]
known_paths = set()
for path in {system_sites!r}:
site.addsitedir(path, known_paths=known_paths)
system_paths = set(
os.path.normcase(path)
for path in sys.path[len(original_sys_path):]
)
original_sys_path = [
path for path in original_sys_path
if os.path.normcase(path) not in system_paths
]
sys.path = original_sys_path
# Second, add lib directories.
# ensuring .pth file are processed.
for path in {lib_dirs!r}:
assert not path in sys.path
site.addsitedir(path)
"""
).format(system_sites=system_sites, lib_dirs=self._lib_dirs)
)
def __enter__(self) -> None:
self._save_env = {
name: os.environ.get(name, None)
for name in ("PATH", "PYTHONNOUSERSITE", "PYTHONPATH")
}
path = self._bin_dirs[:]
old_path = self._save_env["PATH"]
if old_path:
path.extend(old_path.split(os.pathsep))
pythonpath = [self._site_dir]
os.environ.update(
{
"PATH": os.pathsep.join(path),
"PYTHONNOUSERSITE": "1",
"PYTHONPATH": os.pathsep.join(pythonpath),
}
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
for varname, old_value in self._save_env.items():
if old_value is None:
os.environ.pop(varname, None)
else:
os.environ[varname] = old_value
def check_requirements(
self, reqs: Iterable[str]
) -> Tuple[Set[Tuple[str, str]], Set[str]]:
"""Return 2 sets:
- conflicting requirements: set of (installed, wanted) reqs tuples
- missing requirements: set of reqs
"""
missing = set()
conflicting = set()
if reqs:
env = (
get_environment(self._lib_dirs)
if hasattr(self, "_lib_dirs")
else get_default_environment()
)
for req_str in reqs:
req = Requirement(req_str)
# We're explicitly evaluating with an empty extra value, since build
# environments are not provided any mechanism to select specific extras.
if req.marker is not None and not req.marker.evaluate({"extra": ""}):
continue
dist = env.get_distribution(req.name)
if not dist:
missing.add(req_str)
continue
if isinstance(dist.version, Version):
installed_req_str = f"{req.name}=={dist.version}"
else:
installed_req_str = f"{req.name}==={dist.version}"
if not req.specifier.contains(dist.version, prereleases=True):
conflicting.add((installed_req_str, req_str))
# FIXME: Consider direct URL?
return conflicting, missing
def install_requirements(
self,
finder: "PackageFinder",
requirements: Iterable[str],
prefix_as_string: str,
*,
kind: str,
) -> None:
prefix = self._prefixes[prefix_as_string]
assert not prefix.setup
prefix.setup = True
if not requirements:
return
self._install_requirements(
get_runnable_pip(),
finder,
requirements,
prefix,
kind=kind,
)
@staticmethod
def _install_requirements(
pip_runnable: str,
finder: "PackageFinder",
requirements: Iterable[str],
prefix: _Prefix,
*,
kind: str,
) -> None:
args: List[str] = [
sys.executable,
pip_runnable,
"install",
"--ignore-installed",
"--no-user",
"--prefix",
prefix.path,
"--no-warn-script-location",
]
if logger.getEffectiveLevel() <= logging.DEBUG:
args.append("-v")
for format_control in ("no_binary", "only_binary"):
formats = getattr(finder.format_control, format_control)
args.extend(
(
"--" + format_control.replace("_", "-"),
",".join(sorted(formats or {":none:"})),
)
)
index_urls = finder.index_urls
if index_urls:
args.extend(["-i", index_urls[0]])
for extra_index in index_urls[1:]:
args.extend(["--extra-index-url", extra_index])
else:
args.append("--no-index")
for link in finder.find_links:
args.extend(["--find-links", link])
for host in finder.trusted_hosts:
args.extend(["--trusted-host", host])
if finder.allow_all_prereleases:
args.append("--pre")
if finder.prefer_binary:
args.append("--prefer-binary")
args.append("--")
args.extend(requirements)
extra_environ = {"_PIP_STANDALONE_CERT": where()}
with open_spinner(f"Installing {kind}") as spinner:
call_subprocess(
args,
command_desc=f"pip subprocess to install {kind}",
spinner=spinner,
extra_environ=extra_environ,
)
class NoOpBuildEnvironment(BuildEnvironment):
"""A no-op drop-in replacement for BuildEnvironment"""
def __init__(self) -> None:
pass
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass
def cleanup(self) -> None:
pass
def install_requirements(
self,
finder: "PackageFinder",
requirements: Iterable[str],
prefix_as_string: str,
*,
kind: str,
) -> None:
raise NotImplementedError()
"""Cache Management
"""
import hashlib
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from pip._vendor.packaging.tags import Tag, interpreter_name, interpreter_version
from pip._vendor.packaging.utils import canonicalize_name
from pip._internal.exceptions import InvalidWheelFilename
from pip._internal.models.direct_url import DirectUrl
from pip._internal.models.format_control import FormatControl
from pip._internal.models.link import Link
from pip._internal.models.wheel import Wheel
from pip._internal.utils.temp_dir import TempDirectory, tempdir_kinds
from pip._internal.utils.urls import path_to_url
logger = logging.getLogger(__name__)
ORIGIN_JSON_NAME = "origin.json"
def _hash_dict(d: Dict[str, str]) -> str:
"""Return a stable sha224 of a dictionary."""
s = json.dumps(d, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
return hashlib.sha224(s.encode("ascii")).hexdigest()
class Cache:
"""An abstract class - provides cache directories for data from links
:param cache_dir: The root of the cache.
:param format_control: An object of FormatControl class to limit
binaries being read from the cache.
:param allowed_formats: which formats of files the cache should store.
('binary' and 'source' are the only allowed values)
"""
def __init__(
self, cache_dir: str, format_control: FormatControl, allowed_formats: Set[str]
) -> None:
super().__init__()
assert not cache_dir or os.path.isabs(cache_dir)
self.cache_dir = cache_dir or None
self.format_control = format_control
self.allowed_formats = allowed_formats
_valid_formats = {"source", "binary"}
assert self.allowed_formats.union(_valid_formats) == _valid_formats
def _get_cache_path_parts(self, link: Link) -> List[str]:
"""Get parts of part that must be os.path.joined with cache_dir"""
# We want to generate an url to use as our cache key, we don't want to
# just re-use the URL because it might have other items in the fragment
# and we don't care about those.
key_parts = {"url": link.url_without_fragment}
if link.hash_name is not None and link.hash is not None:
key_parts[link.hash_name] = link.hash
if link.subdirectory_fragment:
key_parts["subdirectory"] = link.subdirectory_fragment
# Include interpreter name, major and minor version in cache key
# to cope with ill-behaved sdists that build a different wheel
# depending on the python version their setup.py is being run on,
# and don't encode the difference in compatibility tags.
# https://github.com/pypa/pip/issues/7296
key_parts["interpreter_name"] = interpreter_name()
key_parts["interpreter_version"] = interpreter_version()
# Encode our key url with sha224, we'll use this because it has similar
# security properties to sha256, but with a shorter total output (and
# thus less secure). However the differences don't make a lot of
# difference for our use case here.
hashed = _hash_dict(key_parts)
# We want to nest the directories some to prevent having a ton of top
# level directories where we might run out of sub directories on some
# FS.
parts = [hashed[:2], hashed[2:4], hashed[4:6], hashed[6:]]
return parts
def _get_candidates(self, link: Link, canonical_package_name: str) -> List[Any]:
can_not_cache = not self.cache_dir or not canonical_package_name or not link
if can_not_cache:
return []
formats = self.format_control.get_allowed_formats(canonical_package_name)
if not self.allowed_formats.intersection(formats):
return []
candidates = []
path = self.get_path_for_link(link)
if os.path.isdir(path):
for candidate in os.listdir(path):
candidates.append((candidate, path))
return candidates
def get_path_for_link(self, link: Link) -> str:
"""Return a directory to store cached items in for link."""
raise NotImplementedError()
def get(
self,
link: Link,
package_name: Optional[str],
supported_tags: List[Tag],
) -> Link:
"""Returns a link to a cached item if it exists, otherwise returns the
passed link.
"""
raise NotImplementedError()
class SimpleWheelCache(Cache):
"""A cache of wheels for future installs."""
def __init__(self, cache_dir: str, format_control: FormatControl) -> None:
super().__init__(cache_dir, format_control, {"binary"})
def get_path_for_link(self, link: Link) -> str:
"""Return a directory to store cached wheels for link
Because there are M wheels for any one sdist, we provide a directory
to cache them in, and then consult that directory when looking up
cache hits.
We only insert things into the cache if they have plausible version
numbers, so that we don't contaminate the cache with things that were
not unique. E.g. ./package might have dozens of installs done for it
and build a version of 0.0...and if we built and cached a wheel, we'd
end up using the same wheel even if the source has been edited.
:param link: The link of the sdist for which this will cache wheels.
"""
parts = self._get_cache_path_parts(link)
assert self.cache_dir
# Store wheels within the root cache_dir
return os.path.join(self.cache_dir, "wheels", *parts)
def get(
self,
link: Link,
package_name: Optional[str],
supported_tags: List[Tag],
) -> Link:
candidates = []
if not package_name:
return link
canonical_package_name = canonicalize_name(package_name)
for wheel_name, wheel_dir in self._get_candidates(link, canonical_package_name):
try:
wheel = Wheel(wheel_name)
except InvalidWheelFilename:
continue
if canonicalize_name(wheel.name) != canonical_package_name:
logger.debug(
"Ignoring cached wheel %s for %s as it "
"does not match the expected distribution name %s.",
wheel_name,
link,
package_name,
)
continue
if not wheel.supported(supported_tags):
# Built for a different python/arch/etc
continue
candidates.append(
(
wheel.support_index_min(supported_tags),
wheel_name,
wheel_dir,
)
)
if not candidates:
return link
_, wheel_name, wheel_dir = min(candidates)
return Link(path_to_url(os.path.join(wheel_dir, wheel_name)))
class EphemWheelCache(SimpleWheelCache):
"""A SimpleWheelCache that creates it's own temporary cache directory"""
def __init__(self, format_control: FormatControl) -> None:
self._temp_dir = TempDirectory(
kind=tempdir_kinds.EPHEM_WHEEL_CACHE,
globally_managed=True,
)
super().__init__(self._temp_dir.path, format_control)
class CacheEntry:
def __init__(
self,
link: Link,
persistent: bool,
):
self.link = link
self.persistent = persistent
self.origin: Optional[DirectUrl] = None
origin_direct_url_path = Path(self.link.file_path).parent / ORIGIN_JSON_NAME
if origin_direct_url_path.exists():
self.origin = DirectUrl.from_json(origin_direct_url_path.read_text())
class WheelCache(Cache):
"""Wraps EphemWheelCache and SimpleWheelCache into a single Cache
This Cache allows for gracefully degradation, using the ephem wheel cache
when a certain link is not found in the simple wheel cache first.
"""
def __init__(
self, cache_dir: str, format_control: Optional[FormatControl] = None
) -> None:
if format_control is None:
format_control = FormatControl()
super().__init__(cache_dir, format_control, {"binary"})
self._wheel_cache = SimpleWheelCache(cache_dir, format_control)
self._ephem_cache = EphemWheelCache(format_control)
def get_path_for_link(self, link: Link) -> str:
return self._wheel_cache.get_path_for_link(link)
def get_ephem_path_for_link(self, link: Link) -> str:
return self._ephem_cache.get_path_for_link(link)
def get(
self,
link: Link,
package_name: Optional[str],
supported_tags: List[Tag],
) -> Link:
cache_entry = self.get_cache_entry(link, package_name, supported_tags)
if cache_entry is None:
return link
return cache_entry.link
def get_cache_entry(
self,
link: Link,
package_name: Optional[str],
supported_tags: List[Tag],
) -> Optional[CacheEntry]:
"""Returns a CacheEntry with a link to a cached item if it exists or
None. The cache entry indicates if the item was found in the persistent
or ephemeral cache.
"""
retval = self._wheel_cache.get(
link=link,
package_name=package_name,
supported_tags=supported_tags,
)
if retval is not link:
return CacheEntry(retval, persistent=True)
retval = self._ephem_cache.get(
link=link,
package_name=package_name,
supported_tags=supported_tags,
)
if retval is not link:
return CacheEntry(retval, persistent=False)
return None
@staticmethod
def record_download_origin(cache_dir: str, download_info: DirectUrl) -> None:
origin_path = Path(cache_dir) / ORIGIN_JSON_NAME
if origin_path.is_file():
origin = DirectUrl.from_json(origin_path.read_text())
# TODO: use DirectUrl.equivalent when https://github.com/pypa/pip/pull/10564
# is merged.
if origin.url != download_info.url:
logger.warning(
"Origin URL %s in cache entry %s does not match download URL %s. "
"This is likely a pip bug or a cache corruption issue.",
origin.url,
cache_dir,
download_info.url,
)
origin_path.write_text(download_info.to_json(), encoding="utf-8")
"""Subpackage containing all of pip's command line interface related code
"""
# This file intentionally does not import submodules
"""Logic that powers autocompletion installed by ``pip completion``.
"""
import optparse
import os
import sys
from itertools import chain
from typing import Any, Iterable, List, Optional
from pip._internal.cli.main_parser import create_main_parser
from pip._internal.commands import commands_dict, create_command
from pip._internal.metadata import get_default_environment
def autocomplete() -> None:
"""Entry Point for completion of main and subcommand options."""
# Don't complete if user hasn't sourced bash_completion file.
if "PIP_AUTO_COMPLETE" not in os.environ:
return
cwords = os.environ["COMP_WORDS"].split()[1:]
cword = int(os.environ["COMP_CWORD"])
try:
current = cwords[cword - 1]
except IndexError:
current = ""
parser = create_main_parser()
subcommands = list(commands_dict)
options = []
# subcommand
subcommand_name: Optional[str] = None
for word in cwords:
if word in subcommands:
subcommand_name = word
break
# subcommand options
if subcommand_name is not None:
# special case: 'help' subcommand has no options
if subcommand_name == "help":
sys.exit(1)
# special case: list locally installed dists for show and uninstall
should_list_installed = not current.startswith("-") and subcommand_name in [
"show",
"uninstall",
]
if should_list_installed:
env = get_default_environment()
lc = current.lower()
installed = [
dist.canonical_name
for dist in env.iter_installed_distributions(local_only=True)
if dist.canonical_name.startswith(lc)
and dist.canonical_name not in cwords[1:]
]
# if there are no dists installed, fall back to option completion
if installed:
for dist in installed:
print(dist)
sys.exit(1)
should_list_installables = (
not current.startswith("-") and subcommand_name == "install"
)
if should_list_installables:
for path in auto_complete_paths(current, "path"):
print(path)
sys.exit(1)
subcommand = create_command(subcommand_name)
for opt in subcommand.parser.option_list_all:
if opt.help != optparse.SUPPRESS_HELP:
for opt_str in opt._long_opts + opt._short_opts:
options.append((opt_str, opt.nargs))
# filter out previously specified options from available options
prev_opts = [x.split("=")[0] for x in cwords[1 : cword - 1]]
options = [(x, v) for (x, v) in options if x not in prev_opts]
# filter options by current input
options = [(k, v) for k, v in options if k.startswith(current)]
# get completion type given cwords and available subcommand options
completion_type = get_path_completion_type(
cwords,
cword,
subcommand.parser.option_list_all,
)
# get completion files and directories if ``completion_type`` is
# ``<file>``, ``<dir>`` or ``<path>``
if completion_type:
paths = auto_complete_paths(current, completion_type)
options = [(path, 0) for path in paths]
for option in options:
opt_label = option[0]
# append '=' to options which require args
if option[1] and option[0][:2] == "--":
opt_label += "="
print(opt_label)
else:
# show main parser options only when necessary
opts = [i.option_list for i in parser.option_groups]
opts.append(parser.option_list)
flattened_opts = chain.from_iterable(opts)
if current.startswith("-"):
for opt in flattened_opts:
if opt.help != optparse.SUPPRESS_HELP:
subcommands += opt._long_opts + opt._short_opts
else:
# get completion type given cwords and all available options
completion_type = get_path_completion_type(cwords, cword, flattened_opts)
if completion_type:
subcommands = list(auto_complete_paths(current, completion_type))
print(" ".join([x for x in subcommands if x.startswith(current)]))
sys.exit(1)
def get_path_completion_type(
cwords: List[str], cword: int, opts: Iterable[Any]
) -> Optional[str]:
"""Get the type of path completion (``file``, ``dir``, ``path`` or None)
:param cwords: same as the environmental variable ``COMP_WORDS``
:param cword: same as the environmental variable ``COMP_CWORD``
:param opts: The available options to check
:return: path completion type (``file``, ``dir``, ``path`` or None)
"""
if cword < 2 or not cwords[cword - 2].startswith("-"):
return None
for opt in opts:
if opt.help == optparse.SUPPRESS_HELP:
continue
for o in str(opt).split("/"):
if cwords[cword - 2].split("=")[0] == o:
if not opt.metavar or any(
x in ("path", "file", "dir") for x in opt.metavar.split("/")
):
return opt.metavar
return None
def auto_complete_paths(current: str, completion_type: str) -> Iterable[str]:
"""If ``completion_type`` is ``file`` or ``path``, list all regular files
and directories starting with ``current``; otherwise only list directories
starting with ``current``.
:param current: The word to be completed
:param completion_type: path completion type(``file``, ``path`` or ``dir``)
:return: A generator of regular files and/or directories
"""
directory, filename = os.path.split(current)
current_path = os.path.abspath(directory)
# Don't complete paths if they can't be accessed
if not os.access(current_path, os.R_OK):
return
filename = os.path.normcase(filename)
# list all files that start with ``filename``
file_list = (
x for x in os.listdir(current_path) if os.path.normcase(x).startswith(filename)
)
for f in file_list:
opt = os.path.join(current_path, f)
comp_file = os.path.normcase(os.path.join(directory, f))
# complete regular files when there is not ``<dir>`` after option
# complete directories when there is ``<file>``, ``<path>`` or
# ``<dir>``after option
if completion_type != "dir" and os.path.isfile(opt):
yield comp_file
elif os.path.isdir(opt):
yield os.path.join(comp_file, "")
"""Base Command class, and related routines"""
import functools
import logging
import logging.config
import optparse
import os
import sys
import traceback
from optparse import Values
from typing import Any, Callable, List, Optional, Tuple
from pip._vendor.rich import traceback as rich_traceback
from pip._internal.cli import cmdoptions
from pip._internal.cli.command_context import CommandContextMixIn
from pip._internal.cli.parser import ConfigOptionParser, UpdatingDefaultsHelpFormatter
from pip._internal.cli.status_codes import (
ERROR,
PREVIOUS_BUILD_DIR_ERROR,
UNKNOWN_ERROR,
VIRTUALENV_NOT_FOUND,
)
from pip._internal.exceptions import (
BadCommand,
CommandError,
DiagnosticPipError,
InstallationError,
NetworkConnectionError,
PreviousBuildDirError,
UninstallationError,
)
from pip._internal.utils.filesystem import check_path_owner
from pip._internal.utils.logging import BrokenStdoutLoggingError, setup_logging
from pip._internal.utils.misc import get_prog, normalize_path
from pip._internal.utils.temp_dir import TempDirectoryTypeRegistry as TempDirRegistry
from pip._internal.utils.temp_dir import global_tempdir_manager, tempdir_registry
from pip._internal.utils.virtualenv import running_under_virtualenv
__all__ = ["Command"]
logger = logging.getLogger(__name__)
class Command(CommandContextMixIn):
usage: str = ""
ignore_require_venv: bool = False
def __init__(self, name: str, summary: str, isolated: bool = False) -> None:
super().__init__()
self.name = name
self.summary = summary
self.parser = ConfigOptionParser(
usage=self.usage,
prog=f"{get_prog()} {name}",
formatter=UpdatingDefaultsHelpFormatter(),
add_help_option=False,
name=name,
description=self.__doc__,
isolated=isolated,
)
self.tempdir_registry: Optional[TempDirRegistry] = None
# Commands should add options to this option group
optgroup_name = f"{self.name.capitalize()} Options"
self.cmd_opts = optparse.OptionGroup(self.parser, optgroup_name)
# Add the general options
gen_opts = cmdoptions.make_option_group(
cmdoptions.general_group,
self.parser,
)
self.parser.add_option_group(gen_opts)
self.add_options()
def add_options(self) -> None:
pass
def handle_pip_version_check(self, options: Values) -> None:
"""
This is a no-op so that commands by default do not do the pip version
check.
"""
# Make sure we do the pip version check if the index_group options
# are present.
assert not hasattr(options, "no_index")
def run(self, options: Values, args: List[str]) -> int:
raise NotImplementedError
def parse_args(self, args: List[str]) -> Tuple[Values, List[str]]:
# factored out for testability
return self.parser.parse_args(args)
def main(self, args: List[str]) -> int:
try:
with self.main_context():
return self._main(args)
finally:
logging.shutdown()
def _main(self, args: List[str]) -> int:
# We must initialize this before the tempdir manager, otherwise the
# configuration would not be accessible by the time we clean up the
# tempdir manager.
self.tempdir_registry = self.enter_context(tempdir_registry())
# Intentionally set as early as possible so globally-managed temporary
# directories are available to the rest of the code.
self.enter_context(global_tempdir_manager())
options, args = self.parse_args(args)
# Set verbosity so that it can be used elsewhere.
self.verbosity = options.verbose - options.quiet
level_number = setup_logging(
verbosity=self.verbosity,
no_color=options.no_color,
user_log_file=options.log,
)
# TODO: Try to get these passing down from the command?
# without resorting to os.environ to hold these.
# This also affects isolated builds and it should.
if options.no_input:
os.environ["PIP_NO_INPUT"] = "1"
if options.exists_action:
os.environ["PIP_EXISTS_ACTION"] = " ".join(options.exists_action)
if options.require_venv and not self.ignore_require_venv:
# If a venv is required check if it can really be found
if not running_under_virtualenv():
logger.critical("Could not find an activated virtualenv (required).")
sys.exit(VIRTUALENV_NOT_FOUND)
if options.cache_dir:
options.cache_dir = normalize_path(options.cache_dir)
if not check_path_owner(options.cache_dir):
logger.warning(
"The directory '%s' or its parent directory is not owned "
"or is not writable by the current user. The cache "
"has been disabled. Check the permissions and owner of "
"that directory. If executing pip with sudo, you should "
"use sudo's -H flag.",
options.cache_dir,
)
options.cache_dir = None
def intercepts_unhandled_exc(
run_func: Callable[..., int]
) -> Callable[..., int]:
@functools.wraps(run_func)
def exc_logging_wrapper(*args: Any) -> int:
try:
status = run_func(*args)
assert isinstance(status, int)
return status
except DiagnosticPipError as exc:
logger.error("[present-rich] %s", exc)
logger.debug("Exception information:", exc_info=True)
return ERROR
except PreviousBuildDirError as exc:
logger.critical(str(exc))
logger.debug("Exception information:", exc_info=True)
return PREVIOUS_BUILD_DIR_ERROR
except (
InstallationError,
UninstallationError,
BadCommand,
NetworkConnectionError,
) as exc:
logger.critical(str(exc))
logger.debug("Exception information:", exc_info=True)
return ERROR
except CommandError as exc:
logger.critical("%s", exc)
logger.debug("Exception information:", exc_info=True)
return ERROR
except BrokenStdoutLoggingError:
# Bypass our logger and write any remaining messages to
# stderr because stdout no longer works.
print("ERROR: Pipe to stdout was broken", file=sys.stderr)
if level_number <= logging.DEBUG:
traceback.print_exc(file=sys.stderr)
return ERROR
except KeyboardInterrupt:
logger.critical("Operation cancelled by user")
logger.debug("Exception information:", exc_info=True)
return ERROR
except BaseException:
logger.critical("Exception:", exc_info=True)
return UNKNOWN_ERROR
return exc_logging_wrapper
try:
if not options.debug_mode:
run = intercepts_unhandled_exc(self.run)
else:
run = self.run
rich_traceback.install(show_locals=True)
return run(options, args)
finally:
self.handle_pip_version_check(options)
"""
shared options and groups
The principle here is to define options once, but *not* instantiate them
globally. One reason being that options with action='append' can carry state
between parses. pip parses general options twice internally, and shouldn't
pass on state. To be consistent, all options will follow this design.
"""
# The following comment should be removed at some point in the future.
# mypy: strict-optional=False
import importlib.util
import logging
import os
import textwrap
from functools import partial
from optparse import SUPPRESS_HELP, Option, OptionGroup, OptionParser, Values
from textwrap import dedent
from typing import Any, Callable, Dict, Optional, Tuple
from pip._vendor.packaging.utils import canonicalize_name
from pip._internal.cli.parser import ConfigOptionParser
from pip._internal.exceptions import CommandError
from pip._internal.locations import USER_CACHE_DIR, get_src_prefix
from pip._internal.models.format_control import FormatControl
from pip._internal.models.index import PyPI
from pip._internal.models.target_python import TargetPython
from pip._internal.utils.hashes import STRONG_HASHES
from pip._internal.utils.misc import strtobool
logger = logging.getLogger(__name__)
def raise_option_error(parser: OptionParser, option: Option, msg: str) -> None:
"""
Raise an option parsing error using parser.error().
Args:
parser: an OptionParser instance.
option: an Option instance.
msg: the error text.
"""
msg = f"{option} error: {msg}"
msg = textwrap.fill(" ".join(msg.split()))
parser.error(msg)
def make_option_group(group: Dict[str, Any], parser: ConfigOptionParser) -> OptionGroup:
"""
Return an OptionGroup object
group -- assumed to be dict with 'name' and 'options' keys
parser -- an optparse Parser
"""
option_group = OptionGroup(parser, group["name"])
for option in group["options"]:
option_group.add_option(option())
return option_group
def check_dist_restriction(options: Values, check_target: bool = False) -> None:
"""Function for determining if custom platform options are allowed.
:param options: The OptionParser options.
:param check_target: Whether or not to check if --target is being used.
"""
dist_restriction_set = any(
[
options.python_version,
options.platforms,
options.abis,
options.implementation,
]
)
binary_only = FormatControl(set(), {":all:"})
sdist_dependencies_allowed = (
options.format_control != binary_only and not options.ignore_dependencies
)
# Installations or downloads using dist restrictions must not combine
# source distributions and dist-specific wheels, as they are not
# guaranteed to be locally compatible.
if dist_restriction_set and sdist_dependencies_allowed:
raise CommandError(
"When restricting platform and interpreter constraints using "
"--python-version, --platform, --abi, or --implementation, "
"either --no-deps must be set, or --only-binary=:all: must be "
"set and --no-binary must not be set (or must be set to "
":none:)."
)
if check_target:
if dist_restriction_set and not options.target_dir:
raise CommandError(
"Can not use any platform or abi specific options unless "
"installing via '--target'"
)
def _path_option_check(option: Option, opt: str, value: str) -> str:
return os.path.expanduser(value)
def _package_name_option_check(option: Option, opt: str, value: str) -> str:
return canonicalize_name(value)
class PipOption(Option):
TYPES = Option.TYPES + ("path", "package_name")
TYPE_CHECKER = Option.TYPE_CHECKER.copy()
TYPE_CHECKER["package_name"] = _package_name_option_check
TYPE_CHECKER["path"] = _path_option_check
###########
# options #
###########
help_: Callable[..., Option] = partial(
Option,
"-h",
"--help",
dest="help",
action="help",
help="Show help.",
)
debug_mode: Callable[..., Option] = partial(
Option,
"--debug",
dest="debug_mode",
action="store_true",
default=False,
help=(
"Let unhandled exceptions propagate outside the main subroutine, "
"instead of logging them to stderr."
),
)
isolated_mode: Callable[..., Option] = partial(
Option,
"--isolated",
dest="isolated_mode",
action="store_true",
default=False,
help=(
"Run pip in an isolated mode, ignoring environment variables and user "
"configuration."
),
)
require_virtualenv: Callable[..., Option] = partial(
Option,
"--require-virtualenv",
"--require-venv",
dest="require_venv",
action="store_true",
default=False,
help=(
"Allow pip to only run in a virtual environment; "
"exit with an error otherwise."
),
)
override_externally_managed: Callable[..., Option] = partial(
Option,
"--break-system-packages",
dest="override_externally_managed",
action="store_true",
help="Allow pip to modify an EXTERNALLY-MANAGED Python installation",
)
python: Callable[..., Option] = partial(
Option,
"--python",
dest="python",
help="Run pip with the specified Python interpreter.",
)
verbose: Callable[..., Option] = partial(
Option,
"-v",
"--verbose",
dest="verbose",
action="count",
default=0,
help="Give more output. Option is additive, and can be used up to 3 times.",
)
no_color: Callable[..., Option] = partial(
Option,
"--no-color",
dest="no_color",
action="store_true",
default=False,
help="Suppress colored output.",
)
version: Callable[..., Option] = partial(
Option,
"-V",
"--version",
dest="version",
action="store_true",
help="Show version and exit.",
)
quiet: Callable[..., Option] = partial(
Option,
"-q",
"--quiet",
dest="quiet",
action="count",
default=0,
help=(
"Give less output. Option is additive, and can be used up to 3"
" times (corresponding to WARNING, ERROR, and CRITICAL logging"
" levels)."
),
)
progress_bar: Callable[..., Option] = partial(
Option,
"--progress-bar",
dest="progress_bar",
type="choice",
choices=["on", "off"],
default="on",
help="Specify whether the progress bar should be used [on, off] (default: on)",
)
log: Callable[..., Option] = partial(
PipOption,
"--log",
"--log-file",
"--local-log",
dest="log",
metavar="path",
type="path",
help="Path to a verbose appending log.",
)
no_input: Callable[..., Option] = partial(
Option,
# Don't ask for input
"--no-input",
dest="no_input",
action="store_true",
default=False,
help="Disable prompting for input.",
)
proxy: Callable[..., Option] = partial(
Option,
"--proxy",
dest="proxy",
type="str",
default="",
help="Specify a proxy in the form scheme://[user:passwd@]proxy.server:port.",
)
retries: Callable[..., Option] = partial(
Option,
"--retries",
dest="retries",
type="int",
default=5,
help="Maximum number of retries each connection should attempt "
"(default %default times).",
)
timeout: Callable[..., Option] = partial(
Option,
"--timeout",
"--default-timeout",
metavar="sec",
dest="timeout",
type="float",
default=15,
help="Set the socket timeout (default %default seconds).",
)
def exists_action() -> Option:
return Option(
# Option when path already exist
"--exists-action",
dest="exists_action",
type="choice",
choices=["s", "i", "w", "b", "a"],
default=[],
action="append",
metavar="action",
help="Default action when a path already exists: "
"(s)witch, (i)gnore, (w)ipe, (b)ackup, (a)bort.",
)
cert: Callable[..., Option] = partial(
PipOption,
"--cert",
dest="cert",
type="path",
metavar="path",
help=(
"Path to PEM-encoded CA certificate bundle. "
"If provided, overrides the default. "
"See 'SSL Certificate Verification' in pip documentation "
"for more information."
),
)
client_cert: Callable[..., Option] = partial(
PipOption,
"--client-cert",
dest="client_cert",
type="path",
default=None,
metavar="path",
help="Path to SSL client certificate, a single file containing the "
"private key and the certificate in PEM format.",
)
index_url: Callable[..., Option] = partial(
Option,
"-i",
"--index-url",
"--pypi-url",
dest="index_url",
metavar="URL",
default=PyPI.simple_url,
help="Base URL of the Python Package Index (default %default). "
"This should point to a repository compliant with PEP 503 "
"(the simple repository API) or a local directory laid out "
"in the same format.",
)
def extra_index_url() -> Option:
return Option(
"--extra-index-url",
dest="extra_index_urls",
metavar="URL",
action="append",
default=[],
help="Extra URLs of package indexes to use in addition to "
"--index-url. Should follow the same rules as "
"--index-url.",
)
no_index: Callable[..., Option] = partial(
Option,
"--no-index",
dest="no_index",
action="store_true",
default=False,
help="Ignore package index (only looking at --find-links URLs instead).",
)
def find_links() -> Option:
return Option(
"-f",
"--find-links",
dest="find_links",
action="append",
default=[],
metavar="url",
help="If a URL or path to an html file, then parse for links to "
"archives such as sdist (.tar.gz) or wheel (.whl) files. "
"If a local path or file:// URL that's a directory, "
"then look for archives in the directory listing. "
"Links to VCS project URLs are not supported.",
)
def trusted_host() -> Option:
return Option(
"--trusted-host",
dest="trusted_hosts",
action="append",
metavar="HOSTNAME",
default=[],
help="Mark this host or host:port pair as trusted, even though it "
"does not have valid or any HTTPS.",
)
def constraints() -> Option:
return Option(
"-c",
"--constraint",
dest="constraints",
action="append",
default=[],
metavar="file",
help="Constrain versions using the given constraints file. "
"This option can be used multiple times.",
)
def requirements() -> Option:
return Option(
"-r",
"--requirement",
dest="requirements",
action="append",
default=[],
metavar="file",
help="Install from the given requirements file. "
"This option can be used multiple times.",
)
def editable() -> Option:
return Option(
"-e",
"--editable",
dest="editables",
action="append",
default=[],
metavar="path/url",
help=(
"Install a project in editable mode (i.e. setuptools "
'"develop mode") from a local project path or a VCS url.'
),
)
def _handle_src(option: Option, opt_str: str, value: str, parser: OptionParser) -> None:
value = os.path.abspath(value)
setattr(parser.values, option.dest, value)
src: Callable[..., Option] = partial(
PipOption,
"--src",
"--source",
"--source-dir",
"--source-directory",
dest="src_dir",
type="path",
metavar="dir",
default=get_src_prefix(),
action="callback",
callback=_handle_src,
help="Directory to check out editable projects into. "
'The default in a virtualenv is "<venv path>/src". '
'The default for global installs is "<current dir>/src".',
)
def _get_format_control(values: Values, option: Option) -> Any:
"""Get a format_control object."""
return getattr(values, option.dest)
def _handle_no_binary(
option: Option, opt_str: str, value: str, parser: OptionParser
) -> None:
existing = _get_format_control(parser.values, option)
FormatControl.handle_mutual_excludes(
value,
existing.no_binary,
existing.only_binary,
)
def _handle_only_binary(
option: Option, opt_str: str, value: str, parser: OptionParser
) -> None:
existing = _get_format_control(parser.values, option)
FormatControl.handle_mutual_excludes(
value,
existing.only_binary,
existing.no_binary,
)
def no_binary() -> Option:
format_control = FormatControl(set(), set())
return Option(
"--no-binary",
dest="format_control",
action="callback",
callback=_handle_no_binary,
type="str",
default=format_control,
help="Do not use binary packages. Can be supplied multiple times, and "
'each time adds to the existing value. Accepts either ":all:" to '
'disable all binary packages, ":none:" to empty the set (notice '
"the colons), or one or more package names with commas between "
"them (no colons). Note that some packages are tricky to compile "
"and may fail to install when this option is used on them.",
)
def only_binary() -> Option:
format_control = FormatControl(set(), set())
return Option(
"--only-binary",
dest="format_control",
action="callback",
callback=_handle_only_binary,
type="str",
default=format_control,
help="Do not use source packages. Can be supplied multiple times, and "
'each time adds to the existing value. Accepts either ":all:" to '
'disable all source packages, ":none:" to empty the set, or one '
"or more package names with commas between them. Packages "
"without binary distributions will fail to install when this "
"option is used on them.",
)
platforms: Callable[..., Option] = partial(
Option,
"--platform",
dest="platforms",
metavar="platform",
action="append",
default=None,
help=(
"Only use wheels compatible with <platform>. Defaults to the "
"platform of the running system. Use this option multiple times to "
"specify multiple platforms supported by the target interpreter."
),
)
# This was made a separate function for unit-testing purposes.
def _convert_python_version(value: str) -> Tuple[Tuple[int, ...], Optional[str]]:
"""
Convert a version string like "3", "37", or "3.7.3" into a tuple of ints.
:return: A 2-tuple (version_info, error_msg), where `error_msg` is
non-None if and only if there was a parsing error.
"""
if not value:
# The empty string is the same as not providing a value.
return (None, None)
parts = value.split(".")
if len(parts) > 3:
return ((), "at most three version parts are allowed")
if len(parts) == 1:
# Then we are in the case of "3" or "37".
value = parts[0]
if len(value) > 1:
parts = [value[0], value[1:]]
try:
version_info = tuple(int(part) for part in parts)
except ValueError:
return ((), "each version part must be an integer")
return (version_info, None)
def _handle_python_version(
option: Option, opt_str: str, value: str, parser: OptionParser
) -> None:
"""
Handle a provided --python-version value.
"""
version_info, error_msg = _convert_python_version(value)
if error_msg is not None:
msg = "invalid --python-version value: {!r}: {}".format(
value,
error_msg,
)
raise_option_error(parser, option=option, msg=msg)
parser.values.python_version = version_info
python_version: Callable[..., Option] = partial(
Option,
"--python-version",
dest="python_version",
metavar="python_version",
action="callback",
callback=_handle_python_version,
type="str",
default=None,
help=dedent(
"""\
The Python interpreter version to use for wheel and "Requires-Python"
compatibility checks. Defaults to a version derived from the running
interpreter. The version can be specified using up to three dot-separated
integers (e.g. "3" for 3.0.0, "3.7" for 3.7.0, or "3.7.3"). A major-minor
version can also be given as a string without dots (e.g. "37" for 3.7.0).
"""
),
)
implementation: Callable[..., Option] = partial(
Option,
"--implementation",
dest="implementation",
metavar="implementation",
default=None,
help=(
"Only use wheels compatible with Python "
"implementation <implementation>, e.g. 'pp', 'jy', 'cp', "
" or 'ip'. If not specified, then the current "
"interpreter implementation is used. Use 'py' to force "
"implementation-agnostic wheels."
),
)
abis: Callable[..., Option] = partial(
Option,
"--abi",
dest="abis",
metavar="abi",
action="append",
default=None,
help=(
"Only use wheels compatible with Python abi <abi>, e.g. 'pypy_41'. "
"If not specified, then the current interpreter abi tag is used. "
"Use this option multiple times to specify multiple abis supported "
"by the target interpreter. Generally you will need to specify "
"--implementation, --platform, and --python-version when using this "
"option."
),
)
def add_target_python_options(cmd_opts: OptionGroup) -> None:
cmd_opts.add_option(platforms())
cmd_opts.add_option(python_version())
cmd_opts.add_option(implementation())
cmd_opts.add_option(abis())
def make_target_python(options: Values) -> TargetPython:
target_python = TargetPython(
platforms=options.platforms,
py_version_info=options.python_version,
abis=options.abis,
implementation=options.implementation,
)
return target_python
def prefer_binary() -> Option:
return Option(
"--prefer-binary",
dest="prefer_binary",
action="store_true",
default=False,
help="Prefer older binary packages over newer source packages.",
)
cache_dir: Callable[..., Option] = partial(
PipOption,
"--cache-dir",
dest="cache_dir",
default=USER_CACHE_DIR,
metavar="dir",
type="path",
help="Store the cache data in <dir>.",
)
def _handle_no_cache_dir(
option: Option, opt: str, value: str, parser: OptionParser
) -> None:
"""
Process a value provided for the --no-cache-dir option.
This is an optparse.Option callback for the --no-cache-dir option.
"""
# The value argument will be None if --no-cache-dir is passed via the
# command-line, since the option doesn't accept arguments. However,
# the value can be non-None if the option is triggered e.g. by an
# environment variable, like PIP_NO_CACHE_DIR=true.
if value is not None:
# Then parse the string value to get argument error-checking.
try:
strtobool(value)
except ValueError as exc:
raise_option_error(parser, option=option, msg=str(exc))
# Originally, setting PIP_NO_CACHE_DIR to a value that strtobool()
# converted to 0 (like "false" or "no") caused cache_dir to be disabled
# rather than enabled (logic would say the latter). Thus, we disable
# the cache directory not just on values that parse to True, but (for
# backwards compatibility reasons) also on values that parse to False.
# In other words, always set it to False if the option is provided in
# some (valid) form.
parser.values.cache_dir = False
no_cache: Callable[..., Option] = partial(
Option,
"--no-cache-dir",
dest="cache_dir",
action="callback",
callback=_handle_no_cache_dir,
help="Disable the cache.",
)
no_deps: Callable[..., Option] = partial(
Option,
"--no-deps",
"--no-dependencies",
dest="ignore_dependencies",
action="store_true",
default=False,
help="Don't install package dependencies.",
)
ignore_requires_python: Callable[..., Option] = partial(
Option,
"--ignore-requires-python",
dest="ignore_requires_python",
action="store_true",
help="Ignore the Requires-Python information.",
)
no_build_isolation: Callable[..., Option] = partial(
Option,
"--no-build-isolation",
dest="build_isolation",
action="store_false",
default=True,
help="Disable isolation when building a modern source distribution. "
"Build dependencies specified by PEP 518 must be already installed "
"if this option is used.",
)
check_build_deps: Callable[..., Option] = partial(
Option,
"--check-build-dependencies",
dest="check_build_deps",
action="store_true",
default=False,
help="Check the build dependencies when PEP517 is used.",
)
def _handle_no_use_pep517(
option: Option, opt: str, value: str, parser: OptionParser
) -> None:
"""
Process a value provided for the --no-use-pep517 option.
This is an optparse.Option callback for the no_use_pep517 option.
"""
# Since --no-use-pep517 doesn't accept arguments, the value argument
# will be None if --no-use-pep517 is passed via the command-line.
# However, the value can be non-None if the option is triggered e.g.
# by an environment variable, for example "PIP_NO_USE_PEP517=true".
if value is not None:
msg = """A value was passed for --no-use-pep517,
probably using either the PIP_NO_USE_PEP517 environment variable
or the "no-use-pep517" config file option. Use an appropriate value
of the PIP_USE_PEP517 environment variable or the "use-pep517"
config file option instead.
"""
raise_option_error(parser, option=option, msg=msg)
# If user doesn't wish to use pep517, we check if setuptools is installed
# and raise error if it is not.
if not importlib.util.find_spec("setuptools"):
msg = "It is not possible to use --no-use-pep517 without setuptools installed."
raise_option_error(parser, option=option, msg=msg)
# Otherwise, --no-use-pep517 was passed via the command-line.
parser.values.use_pep517 = False
use_pep517: Any = partial(
Option,
"--use-pep517",
dest="use_pep517",
action="store_true",
default=None,
help="Use PEP 517 for building source distributions "
"(use --no-use-pep517 to force legacy behaviour).",
)
no_use_pep517: Any = partial(
Option,
"--no-use-pep517",
dest="use_pep517",
action="callback",
callback=_handle_no_use_pep517,
default=None,
help=SUPPRESS_HELP,
)
def _handle_config_settings(
option: Option, opt_str: str, value: str, parser: OptionParser
) -> None:
key, sep, val = value.partition("=")
if sep != "=":
parser.error(f"Arguments to {opt_str} must be of the form KEY=VAL") # noqa
dest = getattr(parser.values, option.dest)
if dest is None:
dest = {}
setattr(parser.values, option.dest, dest)
dest[key] = val
config_settings: Callable[..., Option] = partial(
Option,
"--config-settings",
dest="config_settings",
type=str,
action="callback",
callback=_handle_config_settings,
metavar="settings",
help="Configuration settings to be passed to the PEP 517 build backend. "
"Settings take the form KEY=VALUE. Use multiple --config-settings options "
"to pass multiple keys to the backend.",
)
install_options: Callable[..., Option] = partial(
Option,
"--install-option",
dest="install_options",
action="append",
metavar="options",
help="This option is deprecated. Using this option with location-changing "
"options may cause unexpected behavior. "
"Use pip-level options like --user, --prefix, --root, and --target.",
)
build_options: Callable[..., Option] = partial(
Option,
"--build-option",
dest="build_options",
metavar="options",
action="append",
help="Extra arguments to be supplied to 'setup.py bdist_wheel'.",
)
global_options: Callable[..., Option] = partial(
Option,
"--global-option",
dest="global_options",
action="append",
metavar="options",
help="Extra global options to be supplied to the setup.py "
"call before the install or bdist_wheel command.",
)
no_clean: Callable[..., Option] = partial(
Option,
"--no-clean",
action="store_true",
default=False,
help="Don't clean up build directories.",
)
pre: Callable[..., Option] = partial(
Option,
"--pre",
action="store_true",
default=False,
help="Include pre-release and development versions. By default, "
"pip only finds stable versions.",
)
disable_pip_version_check: Callable[..., Option] = partial(
Option,
"--disable-pip-version-check",
dest="disable_pip_version_check",
action="store_true",
default=False,
help="Don't periodically check PyPI to determine whether a new version "
"of pip is available for download. Implied with --no-index.",
)
root_user_action: Callable[..., Option] = partial(
Option,
"--root-user-action",
dest="root_user_action",
default="warn",
choices=["warn", "ignore"],
help="Action if pip is run as a root user. By default, a warning message is shown.",
)
def _handle_merge_hash(
option: Option, opt_str: str, value: str, parser: OptionParser
) -> None:
"""Given a value spelled "algo:digest", append the digest to a list
pointed to in a dict by the algo name."""
if not parser.values.hashes:
parser.values.hashes = {}
try:
algo, digest = value.split(":", 1)
except ValueError:
parser.error(
"Arguments to {} must be a hash name " # noqa
"followed by a value, like --hash=sha256:"
"abcde...".format(opt_str)
)
if algo not in STRONG_HASHES:
parser.error(
"Allowed hash algorithms for {} are {}.".format( # noqa
opt_str, ", ".join(STRONG_HASHES)
)
)
parser.values.hashes.setdefault(algo, []).append(digest)
hash: Callable[..., Option] = partial(
Option,
"--hash",
# Hash values eventually end up in InstallRequirement.hashes due to
# __dict__ copying in process_line().
dest="hashes",
action="callback",
callback=_handle_merge_hash,
type="string",
help="Verify that the package's archive matches this "
"hash before installing. Example: --hash=sha256:abcdef...",
)
require_hashes: Callable[..., Option] = partial(
Option,
"--require-hashes",
dest="require_hashes",
action="store_true",
default=False,
help="Require a hash to check each requirement against, for "
"repeatable installs. This option is implied when any package in a "
"requirements file has a --hash option.",
)
list_path: Callable[..., Option] = partial(
PipOption,
"--path",
dest="path",
type="path",
action="append",
help="Restrict to the specified installation path for listing "
"packages (can be used multiple times).",
)
def check_list_path_option(options: Values) -> None:
if options.path and (options.user or options.local):
raise CommandError("Cannot combine '--path' with '--user' or '--local'")
list_exclude: Callable[..., Option] = partial(
PipOption,
"--exclude",
dest="excludes",
action="append",
metavar="package",
type="package_name",
help="Exclude specified package from the output",
)
no_python_version_warning: Callable[..., Option] = partial(
Option,
"--no-python-version-warning",
dest="no_python_version_warning",
action="store_true",
default=False,
help="Silence deprecation warnings for upcoming unsupported Pythons.",
)
use_new_feature: Callable[..., Option] = partial(
Option,
"--use-feature",
dest="features_enabled",
metavar="feature",
action="append",
default=[],
choices=[
"fast-deps",
"truststore",
"no-binary-enable-wheel-cache",
],
help="Enable new functionality, that may be backward incompatible.",
)
use_deprecated_feature: Callable[..., Option] = partial(
Option,
"--use-deprecated",
dest="deprecated_features_enabled",
metavar="feature",
action="append",
default=[],
choices=[
"legacy-resolver",
],
help=("Enable deprecated functionality, that will be removed in the future."),
)
##########
# groups #
##########
general_group: Dict[str, Any] = {
"name": "General Options",
"options": [
help_,
debug_mode,
isolated_mode,
require_virtualenv,
python,
verbose,
version,
quiet,
log,
no_input,
proxy,
retries,
timeout,
exists_action,
trusted_host,
cert,
client_cert,
cache_dir,
no_cache,
disable_pip_version_check,
no_color,
no_python_version_warning,
use_new_feature,
use_deprecated_feature,
],
}
index_group: Dict[str, Any] = {
"name": "Package Index Options",
"options": [
index_url,
extra_index_url,
no_index,
find_links,
],
}
from contextlib import ExitStack, contextmanager
from typing import ContextManager, Generator, TypeVar
_T = TypeVar("_T", covariant=True)
class CommandContextMixIn:
def __init__(self) -> None:
super().__init__()
self._in_main_context = False
self._main_context = ExitStack()
@contextmanager
def main_context(self) -> Generator[None, None, None]:
assert not self._in_main_context
self._in_main_context = True
try:
with self._main_context:
yield
finally:
self._in_main_context = False
def enter_context(self, context_provider: ContextManager[_T]) -> _T:
assert self._in_main_context
return self._main_context.enter_context(context_provider)
"""Primary application entrypoint.
"""
import locale
import logging
import os
import sys
from typing import List, Optional
from pip._internal.cli.autocompletion import autocomplete
from pip._internal.cli.main_parser import parse_command
from pip._internal.commands import create_command
from pip._internal.exceptions import PipError
from pip._internal.utils import deprecation
logger = logging.getLogger(__name__)
# Do not import and use main() directly! Using it directly is actively
# discouraged by pip's maintainers. The name, location and behavior of
# this function is subject to change, so calling it directly is not
# portable across different pip versions.
# In addition, running pip in-process is unsupported and unsafe. This is
# elaborated in detail at
# https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program.
# That document also provides suggestions that should work for nearly
# all users that are considering importing and using main() directly.
# However, we know that certain users will still want to invoke pip
# in-process. If you understand and accept the implications of using pip
# in an unsupported manner, the best approach is to use runpy to avoid
# depending on the exact location of this entry point.
# The following example shows how to use runpy to invoke pip in that
# case:
#
# sys.argv = ["pip", your, args, here]
# runpy.run_module("pip", run_name="__main__")
#
# Note that this will exit the process after running, unlike a direct
# call to main. As it is not safe to do any processing after calling
# main, this should not be an issue in practice.
def main(args: Optional[List[str]] = None) -> int:
if args is None:
args = sys.argv[1:]
# Configure our deprecation warnings to be sent through loggers
deprecation.install_warning_logger()
autocomplete()
try:
cmd_name, cmd_args = parse_command(args)
except PipError as exc:
sys.stderr.write(f"ERROR: {exc}")
sys.stderr.write(os.linesep)
sys.exit(1)
# Needed for locale.getpreferredencoding(False) to work
# in pip._internal.utils.encoding.auto_decode
try:
locale.setlocale(locale.LC_ALL, "")
except locale.Error as e:
# setlocale can apparently crash if locale are uninitialized
logger.debug("Ignoring error %s when setting locale", e)
command = create_command(cmd_name, isolated=("--isolated" in cmd_args))
return command.main(cmd_args)
"""A single place for constructing and exposing the main parser
"""
import os
import subprocess
import sys
from typing import List, Optional, Tuple
from pip._internal.build_env import get_runnable_pip
from pip._internal.cli import cmdoptions
from pip._internal.cli.parser import ConfigOptionParser, UpdatingDefaultsHelpFormatter
from pip._internal.commands import commands_dict, get_similar_commands
from pip._internal.exceptions import CommandError
from pip._internal.utils.misc import get_pip_version, get_prog
__all__ = ["create_main_parser", "parse_command"]
def create_main_parser() -> ConfigOptionParser:
"""Creates and returns the main parser for pip's CLI"""
parser = ConfigOptionParser(
usage="\n%prog <command> [options]",
add_help_option=False,
formatter=UpdatingDefaultsHelpFormatter(),
name="global",
prog=get_prog(),
)
parser.disable_interspersed_args()
parser.version = get_pip_version()
# add the general options
gen_opts = cmdoptions.make_option_group(cmdoptions.general_group, parser)
parser.add_option_group(gen_opts)
# so the help formatter knows
parser.main = True # type: ignore
# create command listing for description
description = [""] + [
f"{name:27} {command_info.summary}"
for name, command_info in commands_dict.items()
]
parser.description = "\n".join(description)
return parser
def identify_python_interpreter(python: str) -> Optional[str]:
# If the named file exists, use it.
# If it's a directory, assume it's a virtual environment and
# look for the environment's Python executable.
if os.path.exists(python):
if os.path.isdir(python):
# bin/python for Unix, Scripts/python.exe for Windows
# Try both in case of odd cases like cygwin.
for exe in ("bin/python", "Scripts/python.exe"):
py = os.path.join(python, exe)
if os.path.exists(py):
return py
else:
return python
# Could not find the interpreter specified
return None
def parse_command(args: List[str]) -> Tuple[str, List[str]]:
parser = create_main_parser()
# Note: parser calls disable_interspersed_args(), so the result of this
# call is to split the initial args into the general options before the
# subcommand and everything else.
# For example:
# args: ['--timeout=5', 'install', '--user', 'INITools']
# general_options: ['--timeout==5']
# args_else: ['install', '--user', 'INITools']
general_options, args_else = parser.parse_args(args)
# --python
if general_options.python and "_PIP_RUNNING_IN_SUBPROCESS" not in os.environ:
# Re-invoke pip using the specified Python interpreter
interpreter = identify_python_interpreter(general_options.python)
if interpreter is None:
raise CommandError(
f"Could not locate Python interpreter {general_options.python}"
)
pip_cmd = [
interpreter,
get_runnable_pip(),
]
pip_cmd.extend(args)
# Set a flag so the child doesn't re-invoke itself, causing
# an infinite loop.
os.environ["_PIP_RUNNING_IN_SUBPROCESS"] = "1"
returncode = 0
try:
proc = subprocess.run(pip_cmd)
returncode = proc.returncode
except (subprocess.SubprocessError, OSError) as exc:
raise CommandError(f"Failed to run pip under {interpreter}: {exc}")
sys.exit(returncode)
# --version
if general_options.version:
sys.stdout.write(parser.version)
sys.stdout.write(os.linesep)
sys.exit()
# pip || pip help -> print_help()
if not args_else or (args_else[0] == "help" and len(args_else) == 1):
parser.print_help()
sys.exit()
# the subcommand name
cmd_name = args_else[0]
if cmd_name not in commands_dict:
guess = get_similar_commands(cmd_name)
msg = [f'unknown command "{cmd_name}"']
if guess:
msg.append(f'maybe you meant "{guess}"')
raise CommandError(" - ".join(msg))
# all the args without the subcommand
cmd_args = args[:]
cmd_args.remove(cmd_name)
return cmd_name, cmd_args