Skip to content

Fix Dependent types in TupledFunction #23615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ class Definitions {
def TupleClass(using Context): ClassSymbol = TupleTypeRef.symbol.asClass
@tu lazy val Tuple_cons: Symbol = TupleClass.requiredMethod("*:")
@tu lazy val TupleModule: Symbol = requiredModule("scala.Tuple")
@tu lazy val Tuple_Elem: Symbol = TupleModule.requiredType("Elem")
@tu lazy val EmptyTupleClass: Symbol = requiredClass("scala.EmptyTuple")
@tu lazy val EmptyTupleModule: Symbol = requiredModule("scala.EmptyTuple")
@tu lazy val NonEmptyTupleTypeRef: TypeRef = requiredClassRef("scala.NonEmptyTuple")
Expand Down Expand Up @@ -1270,6 +1271,32 @@ class Definitions {
else None
}

// Pattern matcher of tuple selectors at the type level
// Matches Tuple.Elem[x, constant] and x._1.type
object TupleSelectorOf:
def unapply(tp: Type)(using Context): Option[(Type, Int)] =
if tp.isRef(Tuple_Elem) then
// get the arg infos for tuple elem
tp.argInfos match
case tupleType :: ConstantType(c) :: Nil
if c.isIntRange && (isTupleNType(tupleType) || tupleType.isRef(PairClass)) =>
Some((tupleType, c.intValue))
// we should probably report an error here, but its probably handled elsewhere
case _ => None
else
tp.dealias match
// very explicitly ONLY check for isTupleNType
// pair class doesn't have selector fields
case TermRef(tupleType, field) if isTupleNType(tupleType) =>
field match
case name: SimpleName =>
name.toString match
case s"_$id" =>
id.toIntOption.map(it => (tupleType, it - 1))
case _ => None
case _ => None
case _ => None

object ArrayOf {
def apply(elem: Type)(using Context): Type =
if (ctx.erasedTypes) JavaArrayType(elem)
Expand Down Expand Up @@ -1473,7 +1500,7 @@ class Definitions {
def patchStdLibClass(denot: ClassDenotation)(using Context): Unit =
// Do not patch the stdlib files if we explicitly disable it
// This is only to be used during the migration of the stdlib
if ctx.settings.YnoStdlibPatches.value then
if ctx.settings.YnoStdlibPatches.value then
return

def patch2(denot: ClassDenotation, patchCls: Symbol): Unit =
Expand Down
91 changes: 76 additions & 15 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import ast.tpd.*
import Synthesizer.*
import sbt.ExtractDependencies.*
import xsbti.api.DependencyContext.*
import dotty.tools.dotc.core.Definitions.MaxTupleArity

/** Synthesize terms for special classes */
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
Expand Down Expand Up @@ -105,29 +106,89 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val synthesizedTupleFunction: SpecialHandler = (formal, span) =>
formal match
case AppliedType(_, funArgs @ fun :: tupled :: Nil) =>
def doesFunctionTupleInto(baseFun: Type, actualArgs: List[Type],
actualRet: Type, tupled: Type) =
tupled =:= constructDependentTupleType(actualArgs, actualRet, defn.isContextFunctionType(baseFun))
def doesFunctionUntupleTo(baseFun: Type, actualArgs: List[Type],
actualRet: Type, untupled: Type) =
untupled =:= untupleDependentTupleType(actualArgs, actualRet, defn.isContextFunctionType(baseFun))

def functionTypeEqual(baseFun: Type, actualArgs: List[Type],
actualRet: Type, expected: Type) =
expected =:= defn.FunctionNOf(actualArgs, actualRet,
defn.isContextFunctionType(baseFun))
expected =:= defn.FunctionNOf(actualArgs, actualRet, defn.isContextFunctionType(baseFun))
def untupleDependentTupleType(args: List[Type], ret: Type, contextual: Boolean): Type =
val methodKind = if contextual then ContextualMethodType else MethodType

val arity = args.length
methodKind(args.indices.map(nme.syntheticParamName).toList)(
mt => args,
mt =>
val tpeMap = new TypeMap:
def apply(tp: Type): Type =
tp match
case defn.TupleSelectorOf(TermParamRef(_, paramNum), fieldNum) =>
if fieldNum >= arity then
NoType
else
mt.paramRefs(fieldNum)
case _ => mapOver(tp)
tpeMap(ret)
)

def constructDependentTupleType(args: List[Type], ret: Type, contextual: Boolean): Type =
val methodKind = if contextual then ContextualMethodType else MethodType

methodKind(List(nme.syntheticParamName(0)))(
mt => List(defn.tupleType(args)),
mt =>
val tpeMap = new TypeMap:
def apply(tp: Type): Type =
tp match
case TermParamRef(binder, paramNum) =>
mt.paramRefs(0).select(nme.selectorName(paramNum))
case _ =>
mapOver(tp)
tpeMap(ret)
).toFunctionType()

val arity: Int =
if defn.isFunctionNType(fun) then
// TupledFunction[(...) => R, ?]
fun.functionArgInfos match
case funArgs :+ funRet
if functionTypeEqual(fun, defn.tupleType(funArgs) :: Nil, funRet, tupled) =>
// TupledFunction[(...funArgs...) => funRet, ?]
funArgs.size
case _ => -1
// dont use functionArgInfos it dealiases and drops dependents

fun.dealias match
case defn.RefinedFunctionOf(method: MethodType) if doesFunctionTupleInto(fun, method.paramInfos, method.resType, tupled) =>
method.paramInfos.size
// poly types are unsupported
case defn.RefinedFunctionOf(_) => -1
case _ =>
fun.functionArgInfos match
case funArgs :+ funRet
if functionTypeEqual(fun, defn.tupleType(funArgs) :: Nil, funRet, tupled) =>
// TupledFunction[(...funArgs...) => funRet, ?]
funArgs.size
case _ => -1
else if defn.isFunctionNType(tupled) then
// TupledFunction[?, (...) => R]
tupled.functionArgInfos match
case tupledArgs :: funRet :: Nil =>
tupledArgs.tupleElementTypes match
case Some(funArgs) if functionTypeEqual(tupled, funArgs, funRet, fun) =>
// TupledFunction[?, ((...funArgs...)) => funRet]
funArgs.size
tupled.dealias match
case defn.RefinedFunctionOf(method: MethodType) =>
method.argInfos match
case tupledArgs :: funRet :: Nil =>
// TupledFunction[?, ((...)) => R]
tupledArgs.tupleElementTypes match
case Some(args) if doesFunctionUntupleTo(tupled, args, funRet, fun) =>
args.size
case _ => -1
case _ => -1
case _ =>
tupled.functionArgInfos match
case tupledArgs :: funRet :: Nil =>
// TupledFunction[?, ((...)) => R]
tupledArgs.tupleElementTypes match
case Some(args) if functionTypeEqual(tupled, args, funRet, fun) =>
args.size
case _ => -1
case _ => -1
case _ => -1
else
// TupledFunction[?, ?]
-1
Expand Down
4 changes: 4 additions & 0 deletions tests/neg/i21808a.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- [E172] Type Error: tests/neg/i21808a.scala:9:63 ---------------------------------------------------------------------
9 | summon[TupledFunction[(x: T, y: T) => x.type, ((T, T)) => T]] // error
| ^
| (x: Test.T, y: Test.T) => x.type cannot be tupled as ((Test.T, Test.T)) => Test.T
10 changes: 10 additions & 0 deletions tests/neg/i21808a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//> using options -experimental

import scala.util.TupledFunction
import scala.util.NotGiven

object Test {
type T

summon[TupledFunction[(x: T, y: T) => x.type, ((T, T)) => T]] // error
}
4 changes: 4 additions & 0 deletions tests/neg/i21808b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- [E172] Type Error: tests/neg/i21808b.scala:9:63 ---------------------------------------------------------------------
9 | summon[TupledFunction[(T, T) => T, (x: (T, T)) => x._1.type]] // error
| ^
| (Test.T, Test.T) => Test.T cannot be tupled as (x: (Test.T, Test.T)) => (x._1 : Test.T)
10 changes: 10 additions & 0 deletions tests/neg/i21808b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//> using options -experimental

import scala.util.TupledFunction
import scala.util.NotGiven

object Test {
type T

summon[TupledFunction[(T, T) => T, (x: (T, T)) => x._1.type]] // error
}
10 changes: 10 additions & 0 deletions tests/pos/i21808.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//> using options -experimental

import scala.util.TupledFunction
import scala.util.NotGiven

object Test {
type T

summon[TupledFunction[(x: T, y: T) => x.type, (x: (T, T)) => x._1.type]]
}
Loading