diff options
author | Adriaan Moors <adriaan.moors@epfl.ch> | 2011-11-25 17:00:16 +0100 |
---|---|---|
committer | Adriaan Moors <adriaanm@gmail.com> | 2011-12-24 17:36:36 +0100 |
commit | fc0c123e3560da190a3daae35214c2be50fd59e6 (patch) | |
tree | 103617341203a282354b74328997a4b7b9f6dbf4 /src | |
parent | 5ee9a14a489c6e56c331914e9db258c0473d4d23 (diff) | |
download | scala-fc0c123e3560da190a3daae35214c2be50fd59e6.tar.gz scala-fc0c123e3560da190a3daae35214c2be50fd59e6.tar.bz2 scala-fc0c123e3560da190a3daae35214c2be50fd59e6.zip |
[vpm] unapplyProd: faster matching for case classes
behold the mythical unapplyProd: it does not exist, yet it promises to speed up pattern matching on case classes
instead of calling the synthetic unapply/unapplySeq, we don't call the mythical synthetic unapplyProd,
since -- if it existed -- it would be the identity anyway for case classes
eventually, we will allow user-defined unapplyProd methods, which should give you almost the same speed as case class matching
for user-defined extractors (i.e., you don't have to wrap the result in an Option: just return something on which we can select _i for i = 1 to N, or return null to indicate match failure)
still need to figure out a way to derive the types for the subpatterns, without requiring you to wrap your result in a ProductN
unapplyProd support for vararg case classes
using caseFieldAccessors instead of synthetic _i
now the compiler bootstraps again, and after this optimization, quick.lib overhead is 70%, quick.comp is 50%
(compiling with a locker built using -Yvirtpatmat, and itself generating code for -Yvirtpatmat)
before the optimization, I think the overhead for quick.comp was close to 100% in this scenario
more robust tupleSel for case classes
TODO:
- pos/t602 -- clean up after type inference as in fromCaseClassUnapply
- run/pf-catch -- implement new-style orElse for partial functions in uncurry
Diffstat (limited to 'src')
-rw-r--r-- | src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala | 262 |
1 files changed, 178 insertions, 84 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala b/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala index 23d855f7b3..c04e4796c4 100644 --- a/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala +++ b/src/compiler/scala/tools/nsc/typechecker/PatMatVirtualiser.scala @@ -136,7 +136,7 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => def translateExtractorPattern(extractor: ExtractorCall): TranslationStep = { if (!extractor.isTyped) throw new TypeError(pos, "Could not typecheck extractor call: "+ extractor) - if (extractor.resultInMonad == ErrorType) throw new TypeError(pos, "Unsupported extractor type: "+ extractor.tpe) + // if (extractor.resultInMonad == ErrorType) throw new TypeError(pos, "Unsupported extractor type: "+ extractor.tpe) // must use type `tp`, which is provided by extractor's result, not the type expected by binder, // as b.info may be based on a Typed type ascription, which has not been taken into account yet by the translation @@ -290,13 +290,16 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => // helper methods: they analyze types and trees in isolation, but they are not (directly) concerned with the structure of the overall translation object ExtractorCall { - def apply(unfun: Tree, args: List[Tree]): ExtractorCall = new ExtractorCall(unfun, args) + def apply(unfun: Tree, args: List[Tree]): ExtractorCall = new ExtractorCallRegular(unfun, args) + def fromCaseClass(fun: Tree, args: List[Tree]): Option[ExtractorCall] = Some(new ExtractorCallProd(fun, args)) + + // THE PRINCIPLED SLOW PATH -- NOT USED // generate a call to the (synthetically generated) extractor of a case class // NOTE: it's an apply, not a select, since in general an extractor call may have multiple argument lists (including an implicit one) // that we need to preserve, so we supply the scrutinee as Ident(nme.SELECTOR_DUMMY), // and replace that dummy by a reference to the actual binder in 
translateExtractorPattern - def fromCaseClass(fun: Tree, args: List[Tree]): Option[ExtractorCall] = { + def fromCaseClassUnapply(fun: Tree, args: List[Tree]): Option[ExtractorCall] = { // TODO: can we rework the typer so we don't have to do all this twice? // undo rewrite performed in (5) of adapt val orig = fun match {case tpt: TypeTree => tpt.original case _ => fun} @@ -342,25 +345,20 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => } } - class ExtractorCall(extractorCallIncludingDummy: Tree, val args: List[Tree]) { - private lazy val Some(Apply(extractorCall, _)) = extractorCallIncludingDummy.find{ case Apply(_, List(Ident(nme.SELECTOR_DUMMY))) => true case _ => false } + abstract class ExtractorCall(val args: List[Tree]) { + val nbSubPats = args.length - def tpe = extractorCall.tpe - def isTyped = (tpe ne NoType) && extractorCall.isTyped - def resultType = tpe.finalResultType - def paramType = tpe.paramTypes.head + // everything okay, captain? + def isTyped : Boolean - // what's the extractor's result type in the monad? - // turn an extractor's result type into something `monadTypeToSubPatTypesAndRefs` understands - lazy val resultInMonad: Type = if(!hasLength(tpe.paramTypes, 1)) ErrorType else { - if (resultType.typeSymbol == BooleanClass) UnitClass.tpe - else { - val monadArgs = resultType.baseType(matchingMonadType.typeSymbol).typeArgs - // assert(monadArgs.length == 1, "unhandled extractor type: "+ extractorTp) // TODO: overloaded unapply?? - if(monadArgs.length == 1) monadArgs(0) - else ErrorType - } - } + def isSeq: Boolean + lazy val lastIsStar = (nbSubPats > 0) && treeInfo.isStar(args.last) + + // to which type should the previous binder be casted? 
+ def paramType : Type + + // binder has been casted to paramType if necessary + def treeMaker(binder: Symbol, pos: Position): TreeMaker // `subPatBinders` are the variables bound by this pattern in the following patterns // subPatBinders are replaced by references to the relevant part of the extractor's result (tuple component, seq element, the result as-is) @@ -374,15 +372,6 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => case bp => bp } - def isSeq = extractorCall.symbol.name == nme.unapplySeq - lazy val nbSubPats = args.length - lazy val lastIsStar = (nbSubPats > 0) && treeInfo.isStar(args.last) - - // the types for the binders corresponding to my subpatterns - // subPatTypes != args map (_.tpe) since the args may have more specific types than the constructor's parameter types - // replace last type (of shape Seq[A]) with RepeatedParam[A] so that formalTypes will - // repeat the last argument type to align the formals with the number of arguments - // require (nbSubPats > 0 && (!lastIsStar || isSeq)) def subPatTypes: List[Type] = if(isSeq) { val TypeRef(pre, SeqClass, args) = seqTp @@ -390,6 +379,128 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => formalTypes(rawSubPatTypes.init :+ typeRef(pre, RepeatedParamClass, args), nbSubPats) } else rawSubPatTypes + protected def rawSubPatTypes: List[Type] + + protected def seqTp = rawSubPatTypes.last baseType SeqClass + protected def seqLenCmp = rawSubPatTypes.last member nme.lengthCompare + protected lazy val firstIndexingBinder = rawSubPatTypes.length - 1 // rawSubPatTypes.last is the Seq, thus there are `rawSubPatTypes.length - 1` non-seq elements in the tuple + protected lazy val lastIndexingBinder = if(lastIsStar) nbSubPats-2 else nbSubPats-1 + protected lazy val expectedLength = lastIndexingBinder - firstIndexingBinder + 1 + protected lazy val minLenToCheck = if(lastIsStar) 1 else 0 + protected def seqTree(binder: Symbol) = tupleSel(binder)(firstIndexingBinder+1) + protected 
def tupleSel(binder: Symbol)(i: Int): Tree = pmgen.tupleSel(binder)(i) + + // the trees that select the subpatterns on the extractor's result, referenced by `binder` + // require isSeq + protected def subPatRefsSeq(binder: Symbol): List[Tree] = { + // only relevant if isSeq: (here to avoid capturing too much in the returned closure) + val indexingIndices = (0 to (lastIndexingBinder-firstIndexingBinder)) + val nbIndexingIndices = indexingIndices.length + + // this error is checked by checkStarPatOK + // if(isSeq) assert(firstIndexingBinder + nbIndexingIndices + (if(lastIsStar) 1 else 0) == nbSubPats, "(resultInMonad, ts, subPatTypes, subPats)= "+(resultInMonad, ts, subPatTypes, subPats)) + // there are `firstIndexingBinder` non-seq tuple elements preceding the Seq + (((1 to firstIndexingBinder) map tupleSel(binder)) ++ + // then we have to index the binder that represents the sequence for the remaining subpatterns, except for... + (indexingIndices map pmgen.index(seqTree(binder))) ++ + // the last one -- if the last subpattern is a sequence wildcard: drop the prefix (indexed by the refs on the line above), return the remainder + (if(!lastIsStar) Nil else List( + if(nbIndexingIndices == 0) seqTree(binder) + else pmgen.drop(seqTree(binder))(nbIndexingIndices)))).toList + } + + // the trees that select the subpatterns on the extractor's result, referenced by `binder` + // require (nbSubPats > 0 && (!lastIsStar || isSeq)) + protected def subPatRefs(binder: Symbol): List[Tree] = { + if (nbSubPats == 0) Nil + else if (isSeq) subPatRefsSeq(binder) + else ((1 to nbSubPats) map tupleSel(binder)).toList + } + + protected def lengthGuard(binder: Symbol): Option[Tree] = + // no need to check unless it's an unapplySeq and the minimal length is non-trivially satisfied + if (!isSeq || (expectedLength < minLenToCheck)) None + else { import CODE._ + // `binder.lengthCompare(expectedLength)` + def checkExpectedLength = (seqTree(binder) DOT seqLenCmp)(LIT(expectedLength)) + + // the 
comparison to perform + // when the last subpattern is a wildcard-star the expectedLength is but a lower bound + // (otherwise equality is required) + def compareOp: (Tree, Tree) => Tree = + if (lastIsStar) _ INT_>= _ + else _ INT_== _ + + // `if (binder != null && $checkExpectedLength [== | >=] 0) then else zero` + Some((seqTree(binder) ANY_!= NULL) AND compareOp(checkExpectedLength, ZERO)) + } + } + + // TODO: to be called when there's a def unapplyProd(x: T): Product_N + // for now only used for case classes -- pretending there's an unapplyProd that's the identity (and don't call it) + class ExtractorCallProd(fun: Tree, args: List[Tree]) extends ExtractorCall(args) { + // TODO: fix the illegal type bound in pos/t602 -- type inference messes up before we get here: + /*override def equals(x$1: Any): Boolean = ... + val o5: Option[com.mosol.sl.Span[Any]] = // Span[Any] --> Any is not a legal type argument for Span! + */ + // private val orig = fun match {case tpt: TypeTree => tpt.original case _ => fun} + // private val origExtractorTp = unapplyMember(orig.symbol.filter(sym => reallyExists(unapplyMember(sym.tpe))).tpe).tpe + // private val extractorTp = if (wellKinded(fun.tpe)) fun.tpe else existentialAbstraction(origExtractorTp.typeParams, origExtractorTp.resultType) + // println("ExtractorCallProd: "+ (fun.tpe, existentialAbstraction(origExtractorTp.typeParams, origExtractorTp.resultType))) + // println("ExtractorCallProd: "+ (fun.tpe, args map (_.tpe))) + private def extractorTp = fun.tpe + + def isTyped = fun.isTyped + + // to which type should the previous binder be casted? 
+ def paramType = extractorTp.finalResultType + + def isSeq: Boolean = rawSubPatTypes.nonEmpty && isRepeatedParamType(rawSubPatTypes.last) + protected def rawSubPatTypes = extractorTp.paramTypes + + // binder has type paramType + def treeMaker(binder: Symbol, pos: Position): TreeMaker = { + // checks binder ne null before chaining to the next extractor + ProductExtractorTreeMaker(binder, lengthGuard(binder), Substitution(subPatBinders, subPatRefs(binder))) + } + +/* TODO: remove special case when the following bug is fixed +scala> :paste +// Entering paste mode (ctrl-D to finish) + +class Foo(x: Other) { x._1 } // BUG: can't refer to _1 if its defining class has not been type checked yet +case class Other(y: String) + +// Exiting paste mode, now interpreting. + +<console>:8: error: value _1 is not a member of Other + class Foo(x: Other) { x._1 } + ^ + +scala> case class Other(y: String) +defined class Other + +scala> class Foo(x: Other) { x._1 } +defined class Foo */ + override protected def tupleSel(binder: Symbol)(i: Int): Tree = { import CODE._ + // reference the (i-1)th case accessor if it exists, otherwise the (i-1)th tuple component + val caseAccs = binder.info.typeSymbol.caseFieldAccessors + if (caseAccs isDefinedAt (i-1)) REF(binder) DOT caseAccs(i-1) + else pmgen.tupleSel(binder)(i) + } + + override def toString(): String = "case class "+ (if (extractorTp eq null) fun else paramType.typeSymbol) +" with arguments "+ args + } + + class ExtractorCallRegular(extractorCallIncludingDummy: Tree, args: List[Tree]) extends ExtractorCall(args) { + private lazy val Some(Apply(extractorCall, _)) = extractorCallIncludingDummy.find{ case Apply(_, List(Ident(nme.SELECTOR_DUMMY))) => true case _ => false } + + def tpe = extractorCall.tpe + def isTyped = (tpe ne NoType) && extractorCall.isTyped && (resultInMonad ne ErrorType) + def paramType = tpe.paramTypes.head + def resultType = tpe.finalResultType + def isSeq = extractorCall.symbol.name == nme.unapplySeq + def 
treeMaker(patBinderOrCasted: Symbol, pos: Position): TreeMaker = { // the extractor call (applied to the binder bound by the flatMap corresponding to the previous (i.e., enclosing/outer) pattern) val extractorApply = atPos(pos)(spliceApply(patBinderOrCasted)) @@ -399,14 +510,22 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => else extractorApply val binder = freshSym(pos, resultInMonad) // can't simplify this when subPatBinders.isEmpty, since UnitClass.tpe is definitely wrong when isSeq, and resultInMonad should always be correct since it comes directly from the extractor's result type - val subpatRefs = if (subPatBinders isEmpty) Nil else subPatRefs(binder) - lengthGuard(binder) match { - case None => ExtractorTreeMaker(patTreeLifted, binder, Substitution(subPatBinders, subpatRefs)) - case Some(lenGuard) => FilteredExtractorTreeMaker(patTreeLifted, lenGuard, binder, Substitution(subPatBinders, subpatRefs)) + case None => ExtractorTreeMaker(patTreeLifted, binder, Substitution(subPatBinders, subPatRefs(binder))) + case Some(lenGuard) => FilteredExtractorTreeMaker(patTreeLifted, lenGuard, binder, Substitution(subPatBinders, subPatRefs(binder))) } } + override protected def seqTree(binder: Symbol): Tree = + if (firstIndexingBinder == 0) CODE.REF(binder) + else super.seqTree(binder) + + // the trees that select the subpatterns on the extractor's result, referenced by `binder` + // require (nbSubPats > 0 && (!lastIsStar || isSeq)) + override protected def subPatRefs(binder: Symbol): List[Tree] = + if (!isSeq && nbSubPats == 1) List(CODE.REF(binder)) // special case for extractors + else super.subPatRefs(binder) + protected def spliceApply(binder: Symbol): Tree = { object splice extends Transformer { override def transform(t: Tree) = t match { @@ -418,7 +537,19 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => splice.transform(extractorCallIncludingDummy) } - private lazy val rawSubPatTypes = + // what's the extractor's result type in 
the monad? + // turn an extractor's result type into something `monadTypeToSubPatTypesAndRefs` understands + protected lazy val resultInMonad: Type = if(!hasLength(tpe.paramTypes, 1)) ErrorType else { + if (resultType.typeSymbol == BooleanClass) UnitClass.tpe + else { + val monadArgs = resultType.baseType(matchingMonadType.typeSymbol).typeArgs + // assert(monadArgs.length == 1, "unhandled extractor type: "+ extractorTp) // TODO: overloaded unapply?? + if(monadArgs.length == 1) monadArgs(0) + else ErrorType + } + } + + protected lazy val rawSubPatTypes = if (resultInMonad.typeSymbol eq UnitClass) Nil else if(nbSubPats == 1) List(resultInMonad) else getProductArgs(resultInMonad) match { @@ -426,56 +557,6 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => case x => x } - private def seqLenCmp = rawSubPatTypes.last member nme.lengthCompare - private def seqTp = rawSubPatTypes.last baseType SeqClass - private lazy val firstIndexingBinder = rawSubPatTypes.length - 1 // rawSubPatTypes.last is the Seq, thus there are `rawSubPatTypes.length - 1` non-seq elements in the tuple - private lazy val lastIndexingBinder = if(lastIsStar) nbSubPats-2 else nbSubPats-1 - private lazy val expectedLength = lastIndexingBinder - firstIndexingBinder + 1 - private lazy val minLenToCheck = if(lastIsStar) 1 else 0 - private def seqTree(binder: Symbol) = if(firstIndexingBinder == 0) CODE.REF(binder) else pmgen.tupleSel(binder)(firstIndexingBinder+1) - - // the trees that select the subpatterns on the extractor's result, referenced by `binder` - // require (nbSubPats > 0 && (!lastIsStar || isSeq)) - private def subPatRefs(binder: Symbol): List[Tree] = { - // only relevant if isSeq: (here to avoid capturing too much in the returned closure) - val indexingIndices = (0 to (lastIndexingBinder-firstIndexingBinder)) - val nbIndexingIndices = indexingIndices.length - - // this error is checked by checkStarPatOK - // if(isSeq) assert(firstIndexingBinder + nbIndexingIndices + 
(if(lastIsStar) 1 else 0) == nbSubPats, "(resultInMonad, ts, subPatTypes, subPats)= "+(resultInMonad, ts, subPatTypes, subPats)) - - (if(isSeq) { - // there are `firstIndexingBinder` non-seq tuple elements preceding the Seq - ((1 to firstIndexingBinder) map pmgen.tupleSel(binder)) ++ - // then we have to index the binder that represents the sequence for the remaining subpatterns, except for... - (indexingIndices map pmgen.index(seqTree(binder))) ++ - // the last one -- if the last subpattern is a sequence wildcard: drop the prefix (indexed by the refs on the line above), return the remainder - (if(!lastIsStar) Nil else List( - if(nbIndexingIndices == 0) seqTree(binder) - else pmgen.drop(seqTree(binder))(nbIndexingIndices))) - } - else if(nbSubPats == 1) List(CODE.REF(binder)) - else ((1 to nbSubPats) map pmgen.tupleSel(binder))).toList - } - - private def lengthGuard(binder: Symbol): Option[Tree] = - // no need to check unless it's an unapplySeq and the minimal length is non-trivially satisfied - if (!isSeq || (expectedLength < minLenToCheck)) None - else { import CODE._ - // `binder.lengthCompare(expectedLength)` - def checkExpectedLength = (seqTree(binder) DOT seqLenCmp)(LIT(expectedLength)) - - // the comparison to perform - // when the last subpattern is a wildcard-star the expectedLength is but a lower bound - // (otherwise equality is required) - def compareOp: (Tree, Tree) => Tree = - if (lastIsStar) _ INT_>= _ - else _ INT_== _ - - // `if (binder != null && $checkExpectedLength [== | >=] 0) then else zero` - Some((seqTree(binder) ANY_!= NULL) AND compareOp(checkExpectedLength, ZERO)) - } - override def toString() = extractorCall +": "+ extractorCall.tpe +" (symbol= "+ extractorCall.symbol +")." 
} @@ -638,6 +719,17 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => */ case class ExtractorTreeMaker(extractor: Tree, nextBinder: Symbol, initialSubstitution: Substitution) extends SingleExtractorTreeMaker + case class ProductExtractorTreeMaker(prevBinder: Symbol, extraCond: Option[Tree], initialSubstitution: Substitution) extends TreeMaker { import CODE._ + def chainBefore(next: Tree): Tree = { + val nullCheck = REF(prevBinder) OBJ_NE NULL + val cond = extraCond match { + case None => nullCheck + case Some(c) => nullCheck AND c + } + pmgen.condOptimized(cond, substitution(next)) + } + } + case class FilteredExtractorTreeMaker(extractor: Tree, guard: Tree, nextBinder: Symbol, initialSubstitution: Substitution) extends FunTreeMaker { def chainBefore(next: Tree): Tree = pmgen.flatMap(extractor, wrapFunSubst(pmgen.condOptimized(guard, next))) setPos extractor.pos @@ -763,6 +855,7 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => def typedOrElse(pt: Type)(thisCase: Tree, elseCase: Tree): Tree def guard(c: Tree): Tree def zero: Tree + def one(res: Tree): Tree // TODO: defaults in traits + self types == broken? 
// def guard(c: Tree, then: Tree, tp: Type): Tree // def cond(c: Tree): Tree = cond(c, UNIT, NoType) @@ -875,6 +968,7 @@ trait PatMatVirtualiser extends ast.TreeDSL { self: Analyzer => // methods in MatchingStrategy (the monad companion) -- used directly in translation def runOrElse(scrut: Tree, matcher: Tree, scrutTp: Type, resTp: Type): Tree = genTypeApply(matchingStrategy DOT vpmName.runOrElse, scrutTp, resTp) APPLY (scrut) APPLY (matcher) // matchingStrategy.runOrElse(scrut)(matcher) def zero: Tree = matchingStrategy DOT vpmName.zero // matchingStrategy.zero + def one(res: Tree): Tree = one(res, NoType) def one(res: Tree, tp: Type = NoType, oneName: Name = vpmName.one): Tree = genTypeApply(matchingStrategy DOT oneName, tp) APPLY (res) // matchingStrategy.one(res) def or(f: Tree, as: List[Tree]): Tree = (matchingStrategy DOT vpmName.or)((f :: as): _*) // matchingStrategy.or(f, as) def guard(c: Tree): Tree = (matchingStrategy DOT vpmName.guard)(c, UNIT) // matchingStrategy.guard(c, then) -- a user-defined guard |