Skip to content

Fix --strict-equality for iteratively visited code #19635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def filtered_errors(self) -> list[ErrorInfo]:

class IterationDependentErrors:
"""An `IterationDependentErrors` instance serves to collect the `unreachable`,
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types,
handled by the individual `IterationErrorWatcher` instances sequentially applied to
the same code section."""
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types and
non-overlapping types, handled by the individual `IterationErrorWatcher` instances
sequentially applied to the same code section."""

# One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per
# iteration step. Meaning of the tuple items: ErrorCode, message, line, column,
Expand All @@ -248,9 +248,16 @@ class IterationDependentErrors:
# end_line, end_column:
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]

# One dictionary of non-overlapping types per iteration step. Meaning of the key
# tuple items: line, column, end_line, end_column, kind:
nonoverlapping_types: list[
dict[tuple[int, int, int | None, int | None, str], tuple[Type, Type]],
]

def __init__(self) -> None:
self.uselessness_errors = []
self.unreachable_lines = []
self.nonoverlapping_types = []
self.revealed_types = defaultdict(list)

def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
Expand All @@ -270,6 +277,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod
context.end_column = error_info[5]
yield error_info[1], context, error_info[0]

def yield_nonoverlapping_types(
self,
) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]:
"""Report expressions were non-overlapping types were detected for all iterations
were the expression was reachable."""

selected = set()
for candidate in set(chain(*self.nonoverlapping_types)):
if all(
(candidate in nonoverlap) or (candidate[0] in lines)
for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines)
):
selected.add(candidate)

persistent_nonoverlaps: dict[
tuple[int, int, int | None, int | None, str], tuple[list[Type], list[Type]]
] = defaultdict(lambda: ([], []))
for nonoverlaps in self.nonoverlapping_types:
for candidate, (left, right) in nonoverlaps.items():
if candidate in selected:
types = persistent_nonoverlaps[candidate]
types[0].append(left)
types[1].append(right)

for error_info, types in persistent_nonoverlaps.items():
context = Context(line=error_info[0], column=error_info[1])
context.end_line = error_info[2]
context.end_column = error_info[3]
yield (types[0], types[1]), error_info[4], context

def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
"""Yield all types revealed in at least one iteration step."""

Expand All @@ -282,8 +319,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:

class IterationErrorWatcher(ErrorWatcher):
"""Error watcher that filters and separately collects `unreachable` errors,
`redundant-expr` and `redundant-casts` errors, and revealed types when analysing
code sections iteratively to help avoid making too-hasty reports."""
`redundant-expr` and `redundant-casts` errors, and revealed types and
non-overlapping types when analysing code sections iteratively to help avoid
making too-hasty reports."""

iteration_dependent_errors: IterationDependentErrors

Expand All @@ -304,6 +342,7 @@ def __init__(
)
self.iteration_dependent_errors = iteration_dependent_errors
iteration_dependent_errors.uselessness_errors.append(set())
iteration_dependent_errors.nonoverlapping_types.append({})
iteration_dependent_errors.unreachable_lines.append(set())

def on_error(self, file: str, info: ErrorInfo) -> bool:
Expand Down
20 changes: 19 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,21 @@ def incompatible_typevar_value(
)

def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:

# In loops (and similar cases), the same expression might be analysed multiple
# times and thereby confronted with different types. We only want to raise a
# `comparison-overlap` error if it occurs in all cases and therefore collect the
# respective types of the current iteration here so that we can report the error
# later if it is persistent over all iteration steps:
for watcher in self.errors.get_watchers():
if watcher._filter:
break
if isinstance(watcher, IterationErrorWatcher):
watcher.iteration_dependent_errors.nonoverlapping_types[-1][
(ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind)
] = (left, right)
return

left_str = "element" if kind == "container" else "left operand"
right_str = "container item" if kind == "container" else "right operand"
message = "Non-overlapping {} check ({} type: {}, {} type: {})"
Expand Down Expand Up @@ -2511,8 +2526,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non
def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None:
for error_info in iter_errors.yield_uselessness_error_infos():
self.fail(*error_info[:2], code=error_info[2])
msu = mypy.typeops.make_simplified_union
for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types():
self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context)
for types, context in iter_errors.yield_revealed_type_infos():
self.reveal_type(mypy.typeops.make_simplified_union(types), context)
self.reveal_type(msu(types), context)


def quote_type_string(type_string: str) -> str:
Expand Down
35 changes: 35 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,41 @@ while x is not None and b():
x = f()
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop1]
# flags: --allow-redefinition-new --local-partial-types --strict-equality

x = 1
while True:
if x == str():
break
x = str()
if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int")
break
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop2]
# flags: --allow-redefinition-new --local-partial-types --strict-equality

class A: ...
class B: ...
class C: ...

x = A()
while True:
if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C")
break
x = B()
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop3]
# flags: --strict-equality

for y in [1.0]:
if y is not None or y != "None":
...

[builtins fixtures/primitives.pyi]

[case testNarrowPromotionsInsideUnions1]

from typing import Union
Expand Down
Loading