Skip to content

Fix meet_types for literal and instance #19605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
# Simplest case: join two types with the same base type (but
# potentially different arguments).

last_known_value = (
None if t.last_known_value != s.last_known_value else t.last_known_value
)

# Combine type arguments.
args: list[Type] = []
# N.B: We use zip instead of indexing because the lengths might have
Expand Down Expand Up @@ -104,10 +108,10 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
new_type = join_types(ta, sa, self)
if len(type_var.values) != 0 and new_type not in type_var.values:
self.seen_instances.pop()
return object_from_instance(t)
return object_from_instance(t, last_known_value=last_known_value)
if not is_subtype(new_type, type_var.upper_bound):
self.seen_instances.pop()
return object_from_instance(t)
return object_from_instance(t, last_known_value=last_known_value)
# TODO: contravariant case should use meet but pass seen instances as
# an argument to keep track of recursive checks.
elif type_var.variance in (INVARIANT, CONTRAVARIANT):
Expand All @@ -117,7 +121,7 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
new_type = ta
elif not is_equivalent(ta, sa):
self.seen_instances.pop()
return object_from_instance(t)
return object_from_instance(t, last_known_value=last_known_value)
else:
# If the types are different but equivalent, then an Any is involved
# so using a join in the contravariant case is also OK.
Expand All @@ -141,11 +145,17 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
new_type = join_types(ta, sa, self)
assert new_type is not None
args.append(new_type)
result: ProperType = Instance(t.type, args)
result: ProperType = Instance(t.type, args, last_known_value=last_known_value)
elif t.type.bases and is_proper_subtype(
t, s, subtype_context=SubtypeContext(ignore_type_params=True)
):
result = self.join_instances_via_supertype(t, s)
elif s.type.bases and is_proper_subtype(
s, t, subtype_context=SubtypeContext(ignore_type_params=True)
):
result = self.join_instances_via_supertype(s, t)
elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)):
result = self.join_instances_via_supertype(t, s)
else:
# Now t is not a subtype of s, and t != s. Now s could be a subtype
# of t; alternatively, we need to find a common supertype. This works
Expand Down Expand Up @@ -621,13 +631,16 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType):
if t == self.s:
# E.g. Literal["x"], Literal["x"] -> Literal["x"]
return t
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
return mypy.typeops.make_simplified_union([self.s, t])
return join_types(self.s.fallback, t.fallback)
elif isinstance(self.s, Instance) and self.s.last_known_value == t:
return t
# E.g. Literal["x"], Literal["x"]? -> Literal["x"]?
return self.s
else:
# E.g. Literal["x"], Literal["y"]? -> str
return join_types(self.s, t.fallback)

def visit_partial_type(self, t: PartialType) -> ProperType:
Expand Down Expand Up @@ -848,10 +861,12 @@ def combine_arg_names(
return new_names


def object_from_instance(instance: Instance) -> Instance:
def object_from_instance(
instance: Instance, last_known_value: LiteralType | None = None
) -> Instance:
"""Construct the type 'builtins.object' from an instance type."""
# Use the fact that 'object' is always the last class in the mro.
res = Instance(instance.type.mro[-1], [])
res = Instance(instance.type.mro[-1], [], last_known_value=last_known_value)
return res


Expand Down
34 changes: 32 additions & 2 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,30 @@ def meet_types(s: Type, t: Type) -> ProperType:
t = get_proper_type(t)

if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type:
# special casing for dealing with last known values
lkv: LiteralType | None

if s.last_known_value is None:
lkv = t.last_known_value
elif t.last_known_value is None:
lkv = s.last_known_value
else:
lkv_meet = meet_types(s.last_known_value, t.last_known_value)
if isinstance(lkv_meet, UninhabitedType):
lkv = None
elif isinstance(lkv_meet, LiteralType):
lkv = lkv_meet
else:
msg = (
f"Unexpected result: "
f"meet of {s.last_known_value=!s} and {t.last_known_value=!s} "
f"resulted in {lkv_meet!s}"
)
raise ValueError(msg)

t = t.copy_modified(last_known_value=lkv)
s = s.copy_modified(last_known_value=lkv)

# Code in checker.py should merge any extra_items where possible, so we
# should have only compatible extra_items here. We check this before
# the below subtype check, so that extra_attrs will not get erased.
Expand Down Expand Up @@ -1088,8 +1112,14 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType) and self.s == t:
return t
elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s):
return t
elif isinstance(self.s, Instance):
# if is_subtype(t.fallback, self.s):
# return t
if self.s.last_known_value is not None:
# meet(Literal["max"]?, Literal["max"]) -> Literal["max"]
# meet(Literal["sum"]?, Literal["max"]) -> Never
return meet_types(self.s.last_known_value, t)
return self.default(self.s)
else:
return self.default(self.s)

Expand Down
3 changes: 2 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
elif top is None:
candidate = bottom
elif is_subtype(bottom, top):
candidate = bottom
# Need to meet in case like Literal["x"]? <: T <: Literal["x"]
candidate = meet_types(bottom, top)
else:
candidate = None
return candidate
Expand Down
28 changes: 26 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,13 @@ def visit_instance(self, left: Instance) -> bool:
assert isinstance(erased, Instance)
t = erased
nominal = True
if self.proper_subtype and right.last_known_value is not None:
if left.last_known_value is None:
# E.g. str is not a proper subtype of Literal["x"]?
nominal = False
else:
# E.g. Literal[A]? <: Literal[B]? requires A <: B
nominal &= self._is_subtype(left.last_known_value, right.last_known_value)
if right.type.has_type_var_tuple_type:
# For variadic instances we simply find the correct type argument mappings,
# all the heavy lifting is done by the tuple subtyping.
Expand Down Expand Up @@ -629,8 +636,14 @@ def visit_instance(self, left: Instance) -> bool:
return True
if isinstance(item, Instance):
return is_named_instance(item, "builtins.object")
if isinstance(right, LiteralType) and left.last_known_value is not None:
return self._is_subtype(left.last_known_value, right)
if isinstance(right, LiteralType):
if self.proper_subtype:
# Instance types like Literal["sum"]? is *assignable* to Literal["sum"],
# but is not a proper subtype of it. (Literal["sum"]? is a gradual type,
# that is a proper subtype of str, and assignable to Literal["sum"].
return False
if left.last_known_value is not None:
return self._is_subtype(left.last_known_value, right)
if isinstance(right, FunctionLike):
# Special case: Instance can be a subtype of Callable / Overloaded.
call = find_member("__call__", left, left, is_operator=True)
Expand Down Expand Up @@ -965,6 +978,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
def visit_literal_type(self, left: LiteralType) -> bool:
if isinstance(self.right, LiteralType):
return left == self.right
elif (
isinstance(self.right, Instance)
and self.right.last_known_value is not None
and self.proper_subtype
):
return self._is_subtype(left, self.right.last_known_value)
else:
return self._is_subtype(left.fallback, self.right)

Expand Down Expand Up @@ -2127,6 +2146,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool:
item = get_proper_type(item)
supertype = get_proper_type(supertype)

# Use last known value for Instance types, if available.
# This ensures that e.g. Literal["max"]? is covered by Literal["max"].
if isinstance(item, Instance) and item.last_known_value is not None:
item = item.last_known_value

# Since runtime type checks will ignore type arguments, erase the types.
if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()):
supertype = erase_type(supertype)
Expand Down
126 changes: 125 additions & 1 deletion mypy/test/testsubtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
from mypy.subtypes import is_subtype
from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away
from mypy.test.helpers import Suite
from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture
from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType
Expand Down Expand Up @@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None:
def test_fallback_not_subtype_of_tuple(self) -> None:
self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a))

def test_literal(self) -> None:
str1 = self.fx.lit_str1
str2 = self.fx.lit_str2
str1_inst = self.fx.lit_str1_inst
str2_inst = self.fx.lit_str2_inst
str_type = self.fx.str_type

# other operand is the fallback type
# "x" ≲ str -> YES
# str ≲ "x" -> NO
# "x"? ≲ str -> YES
# str ≲ "x"? -> YES
self.assert_subtype(str1, str_type)
self.assert_not_subtype(str_type, str1)
self.assert_subtype(str1_inst, str_type)
self.assert_subtype(str_type, str1_inst)

# other operand is the same literal
# "x" ≲ "x" -> YES
# "x" ≲ "x"? -> YES
# "x"? ≲ "x" -> YES
# "x"? ≲ "x"? -> YES
self.assert_subtype(str1, str1)
self.assert_subtype(str1, str1_inst)
self.assert_subtype(str1_inst, str1)
self.assert_subtype(str1_inst, str1_inst)

# other operand is a different literal
# "x" ≲ "y" -> NO
# "x" ≲ "y"? -> YES
# "x"? ≲ "y" -> NO
# "x"? ≲ "y"? -> YES
self.assert_not_subtype(str1, str2)
self.assert_subtype(str1, str2_inst)
self.assert_not_subtype(str1_inst, str2)
self.assert_subtype(str1_inst, str2_inst)

# check proper subtyping
# other operand is the fallback type
# "x" <: str -> YES
# str <: "x" -> NO
# "x"? <: str -> YES
# str <: "x"? -> NO
self.assert_proper_subtype(str1, str_type)
self.assert_not_proper_subtype(str_type, str1)
self.assert_proper_subtype(str1_inst, str_type)
self.assert_not_proper_subtype(str_type, str1_inst)

# other operand is the same literal
# "x" <: "x" -> YES
# "x" <: "x"? -> YES
# "x"? <: "x" -> NO
# "x"? <: "x"? -> YES
self.assert_proper_subtype(str1, str1)
self.assert_proper_subtype(str1, str1_inst)
self.assert_not_proper_subtype(str1_inst, str1)
self.assert_proper_subtype(str1_inst, str1_inst)

# other operand is a different literal
# "x" <: "y" -> NO
# "x" <: "y"? -> NO
# "x"? <: "y" -> NO
# "x"? <: "y"? -> NO
self.assert_not_proper_subtype(str1, str2)
self.assert_not_proper_subtype(str1, str2_inst)
self.assert_not_proper_subtype(str1_inst, str2)
self.assert_not_proper_subtype(str1_inst, str2_inst)

# IDEA: Maybe add these test cases (they are tested pretty well in type
# checker tests already):
# * more interface subtyping test cases
Expand All @@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None:
# * any type
# * generic function types

def assert_proper_subtype(self, s: Type, t: Type) -> None:
assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}"

def assert_not_proper_subtype(self, s: Type, t: Type) -> None:
assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}"

def assert_subtype(self, s: Type, t: Type) -> None:
assert is_subtype(s, t), f"{s} not subtype of {t}"

Expand All @@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None:
def assert_unrelated(self, s: Type, t: Type) -> None:
self.assert_not_subtype(s, t)
self.assert_not_subtype(t, s)


class RestrictionSuite(Suite):
# Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``.

def setUp(self) -> None:
self.fx = TypeFixture()

def assert_restriction(self, s: Type, t: Type, expected: Type) -> None:
actual = restrict_subtype_away(s, t)
msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)"
self.assertEqual(actual, expected, msg=msg.format(actual, expected))

def test_literal(self) -> None:
str1 = self.fx.lit_str1
str2 = self.fx.lit_str2
str1_inst = self.fx.lit_str1_inst
str2_inst = self.fx.lit_str2_inst
str_type = self.fx.str_type
uninhabited = self.fx.uninhabited

# other operand is the fallback type
# "x" - str -> Never
# str - "x" -> str
# "x"? - str -> Never
# str - "x"? -> Never
self.assert_restriction(str1, str_type, uninhabited)
self.assert_restriction(str_type, str1, str_type)
self.assert_restriction(str1_inst, str_type, uninhabited)
self.assert_restriction(str_type, str1_inst, uninhabited)

# other operand is the same literal
# "x" - "x" -> Never
# "x" - "x"? -> Never
# "x"? - "x" -> Never
# "x"? - "x"? -> Never
self.assert_restriction(str1, str1, uninhabited)
self.assert_restriction(str1, str1_inst, uninhabited)
self.assert_restriction(str1_inst, str1, uninhabited)
self.assert_restriction(str1_inst, str1_inst, uninhabited)

# other operand is a different literal
# "x" - "y" -> "x"
# "x" - "y"? -> Never
# "x"? - "y" -> "x"?
# "x"? - "y"? -> Never
self.assert_restriction(str1, str2, str1)
self.assert_restriction(str1, str2_inst, uninhabited)
self.assert_restriction(str1_inst, str2, str1_inst)
self.assert_restriction(str1_inst, str2_inst, uninhabited)
Loading
Loading