From 9bf58090c704a59d8735874c565200758bcea666 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 14 Dec 2016 22:21:51 +0100 Subject: Change by-name pattern matching. New implementation following the scheme outlined in #1790. --- .../src/dotty/tools/dotc/typer/Applications.scala | 62 +++++++++++++++------- 1 file changed, 43 insertions(+), 19 deletions(-) (limited to 'compiler/src/dotty/tools/dotc/typer/Applications.scala') diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 11121e1f3..eca4df617 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -32,16 +32,37 @@ import reporting.diagnostic.Message object Applications { import tpd._ + def extractorMember(tp: Type, name: Name)(implicit ctx: Context) = { + def isPossibleExtractorType(tp: Type) = tp match { + case _: MethodType | _: PolyType => false + case _ => true + } + tp.member(name).suchThat(d => isPossibleExtractorType(d.info)) + } + def extractorMemberType(tp: Type, name: Name, errorPos: Position = NoPosition)(implicit ctx: Context) = { - val ref = tp.member(name).suchThat(_.info.isParameterless) + val ref = extractorMember(tp, name) if (ref.isOverloaded) errorType(i"Overloaded reference to $ref is not allowed in extractor", errorPos) - else if (ref.info.isInstanceOf[PolyType]) - errorType(i"Reference to polymorphic $ref: ${ref.info} is not allowed in extractor", errorPos) - else - ref.info.widenExpr.dealias + ref.info.widenExpr.dealias } + /** Does `tp` fit the "product match" conditions as an unapply result type? + * This is the case of `tp` is a subtype of a ProductN class and `tp` has a + * parameterless `isDefined` member of result type `Boolean`. + */ + def isProductMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) = + extractorMemberType(tp, nme.isDefined, errorPos).isRef(defn.BooleanClass) && + defn.isProductSubType(tp) + + /** Does `tp` fit the "get match" conditions as an unapply result type? + * This is the case of `tp` has a `get` member as well as a + * parameterless `isDefined` member of result type `Boolean`. + */ + def isGetMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) = + extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) && + extractorMemberType(tp, nme.get, errorPos).exists + def productSelectorTypes(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context): List[Type] = { val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos) sels.takeWhile(_.exists).toList @@ -62,24 +83,27 @@ object Applications { def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: Position = NoPosition)(implicit ctx: Context): List[Type] = { def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil) - def getTp = extractorMemberType(unapplyResult, nme.get, pos) - // println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}") - if (extractorMemberType(unapplyResult, nme.isDefined, pos) isRef defn.BooleanClass) { - if (getTp.exists) - if (unapplyFn.symbol.name == nme.unapplySeq) { - val seqArg = boundsToHi(getTp.elemType) - if (seqArg.exists) return args map Function.const(seqArg) - } - else return getUnapplySelectors(getTp, args, pos) - else if (defn.isProductSubType(unapplyResult)) return productSelectorTypes(unapplyResult, pos) - } - if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil - else if (unapplyResult isRef defn.BooleanClass) Nil - else { + def fail = { ctx.error(i"$unapplyResult is not a valid result type of an unapply method of an extractor", pos) Nil } + + // println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}") + if (isProductMatch(unapplyResult)) + productSelectorTypes(unapplyResult) + else if (isGetMatch(unapplyResult)) { + val getTp = extractorMemberType(unapplyResult, nme.get, pos) + if (unapplyFn.symbol.name == nme.unapplySeq) { + val seqArg = boundsToHi(getTp.elemType) + if (seqArg.exists) args.map(Function.const(seqArg)) + else fail + } + else getUnapplySelectors(getTp, args, pos) + } + else if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil + else if (unapplyResult isRef defn.BooleanClass) Nil + else fail } def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(implicit ctx: Context): Tree = -- cgit v1.2.3 From b1553cb9c5894e2e91925a67afd2c986675e5c46 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Thu, 15 Dec 2016 15:21:26 +0100 Subject: Implement new rules for name-based pattern matching This implements the rules laid down in #1805. --- compiler/src/dotty/tools/dotc/ast/Desugar.scala | 4 +- .../src/dotty/tools/dotc/core/Definitions.scala | 10 ++++- compiler/src/dotty/tools/dotc/core/StdNames.scala | 1 - .../tools/dotc/transform/PatternMatcher.scala | 9 +++-- .../src/dotty/tools/dotc/typer/Applications.scala | 44 +++++++++++++--------- tests/pos/Patterns.scala | 28 ++++++++++++++ 6 files changed, 70 insertions(+), 26 deletions(-) (limited to 'compiler/src/dotty/tools/dotc/typer/Applications.scala') diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 15cb0b665..11f8b81eb 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -24,7 +24,6 @@ object desugar { /** Names of methods that are added unconditionally to case classes */ def isDesugaredCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean = - name == nme.isDefined || name == nme.copy || name == nme.productArity || name.isSelectorName @@ -343,7 +342,6 @@ object desugar { if (isCaseClass) { def syntheticProperty(name: TermName, rhs: Tree) = DefDef(name, Nil, Nil, TypeTree(), rhs).withMods(synthetic) - val isDefinedMeth = syntheticProperty(nme.isDefined, Literal(Constant(true))) val caseParams = constrVparamss.head.toArray val productElemMeths = for (i <- 0 until arity) yield syntheticProperty(nme.selectorName(i), Select(This(EmptyTypeIdent), caseParams(i).name)) @@ -369,7 +367,7 @@ object desugar { DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr) .withMods(synthetic) :: Nil } - copyMeths ::: isDefinedMeth :: productElemMeths.toList + copyMeths ::: productElemMeths.toList } else Nil diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 1a7c62b30..9759e39fc 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -675,7 +675,7 @@ class Definitions { private def isVarArityClass(cls: Symbol, prefix: Name) = { val name = scalaClassName(cls) - name.startsWith(prefix) && + name.startsWith(prefix) && name.length > prefix.length && name.drop(prefix.length).forall(_.isDigit) } @@ -737,6 +737,14 @@ class Definitions { def isProductSubType(tp: Type)(implicit ctx: Context) = (tp derivesFrom ProductType.symbol) && tp.baseClasses.exists(isProductClass) + def productArity(tp: Type)(implicit ctx: Context) = + if (tp derivesFrom ProductType.symbol) + tp.baseClasses.find(isProductClass) match { + case Some(prod) => prod.typeParams.length + case None => -1 + } + else -1 + def isFunctionType(tp: Type)(implicit ctx: Context) = isFunctionClass(tp.dealias.typeSymbol) && { val arity = functionArity(tp) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 741ff8b1f..e71893c1e 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -424,7 +424,6 @@ object StdNames { val info: N = "info" val inlinedEquals: N = "inlinedEquals" val isArray: N = "isArray" - val isDefined: N = "isDefined" val isDefinedAt: N = "isDefinedAt" val isDefinedAtImpl: N = "$isDefinedAt" val isEmpty: N = "isEmpty" diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 1252781e8..181dfccd9 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -235,7 +235,8 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { // next: MatchMonad[U] // returns MatchMonad[U] def flatMap(prev: Tree, b: Symbol, next: Tree): Tree = { - if (isProductMatch(prev.tpe)) { + val resultArity = defn.productArity(b.info) + if (isProductMatch(prev.tpe, resultArity)) { val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null))) ifThenElseZero( nullCheck, @@ -1429,7 +1430,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { def resultInMonad = if (aligner.isBool) defn.UnitType - else if (isProductMatch(resultType)) resultType + else if (isProductMatch(resultType, aligner.prodArity)) resultType else if (isGetMatch(resultType)) extractorMemberType(resultType, nme.get) else resultType @@ -1630,7 +1631,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { ref(binder) :: Nil } else if ((aligner.isSingle && aligner.extractor.prodArity == 1) && - !isProductMatch(binderTypeTested) && isGetMatch(binderTypeTested)) + !isProductMatch(binderTypeTested, aligner.prodArity) && isGetMatch(binderTypeTested)) List(ref(binder)) else subPatRefs(binder) @@ -1885,7 +1886,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { else if (result.classSymbol is Flags.CaseClass) result.decls.filter(x => x.is(Flags.CaseAccessor) && x.is(Flags.Method)).map(_.info).toList else result.select(nme.get) :: Nil )*/ - if (isProductMatch(resultType)) productSelectorTypes(resultType) + if (isProductMatch(resultType, args.length)) productSelectorTypes(resultType) else if (isGetMatch(resultType)) getUnapplySelectors(resultOfGet, args) else if (resultType isRef defn.BooleanClass) Nil else { diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index eca4df617..d34804865 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -47,13 +47,13 @@ object Applications { ref.info.widenExpr.dealias } - /** Does `tp` fit the "product match" conditions as an unapply result type? - * This is the case of `tp` is a subtype of a ProductN class and `tp` has a - * parameterless `isDefined` member of result type `Boolean`. + /** Does `tp` fit the "product match" conditions as an unapply result type + * for a pattern with `numArgs` subpatterns> + * This is the case of `tp` is a subtype of the Product class. */ - def isProductMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) = - extractorMemberType(tp, nme.isDefined, errorPos).isRef(defn.BooleanClass) && - defn.isProductSubType(tp) + def isProductMatch(tp: Type, numArgs: Int)(implicit ctx: Context) = + 0 <= numArgs && numArgs <= Definitions.MaxTupleArity && + tp.derivesFrom(defn.ProductNType(numArgs).typeSymbol) /** Does `tp` fit the "get match" conditions as an unapply result type? * This is the case of `tp` has a `get` member as well as a @@ -82,28 +82,38 @@ object Applications { def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: Position = NoPosition)(implicit ctx: Context): List[Type] = { + val unapplyName = unapplyFn.symbol.name def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil) + def getTp = extractorMemberType(unapplyResult, nme.get, pos) def fail = { - ctx.error(i"$unapplyResult is not a valid result type of an unapply method of an extractor", pos) + ctx.error(i"$unapplyResult is not a valid result type of an $unapplyName method of an extractor", pos) Nil } - // println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}") - if (isProductMatch(unapplyResult)) - productSelectorTypes(unapplyResult) - else if (isGetMatch(unapplyResult)) { - val getTp = extractorMemberType(unapplyResult, nme.get, pos) - if (unapplyFn.symbol.name == nme.unapplySeq) { + if (unapplyName == nme.unapplySeq) { + if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil + else if (isGetMatch(unapplyResult, pos)) { val seqArg = boundsToHi(getTp.elemType) if (seqArg.exists) args.map(Function.const(seqArg)) else fail } - else getUnapplySelectors(getTp, args, pos) + else fail + } + else { + assert(unapplyName == nme.unapply) + if (isProductMatch(unapplyResult, args.length)) + productSelectorTypes(unapplyResult) + else if (isGetMatch(unapplyResult, pos)) + getUnapplySelectors(getTp, args, pos) + else if (unapplyResult isRef defn.BooleanClass) + Nil + else if (defn.isProductSubType(unapplyResult)) + productSelectorTypes(unapplyResult) + // this will cause a "wrong number of arguments in pattern" error later on, + // which is better than the message in `fail`. + else fail } - else if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil - else if (unapplyResult isRef defn.BooleanClass) Nil - else fail } def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(implicit ctx: Context): Tree = diff --git a/tests/pos/Patterns.scala b/tests/pos/Patterns.scala index aa369a77b..fd0d7e97a 100644 --- a/tests/pos/Patterns.scala +++ b/tests/pos/Patterns.scala @@ -108,3 +108,31 @@ object NestedPattern { val xss: List[List[String]] = ??? val List(List(x)) = xss } + +// Tricky case (exercised by Scala parser combinators) where we use +// both get/isEmpty and product-based pattern matching in different +// matches on the same types. +object ProductAndGet { + + trait Result[+T] + case class Success[+T](in: String, x: T) extends Result[T] { + def isEmpty = false + def get: T = x + } + case class Failure[+T](in: String, msg: String) extends Result[T] { + def isEmpty = false + def get: String = msg + } + + val r: Result[Int] = ??? + + r match { + case Success(in, x) => x + case Failure(in, msg) => -1 + } + + r match { + case Success(x) => x + case Failure(msg) => -1 + } +} -- cgit v1.2.3