diff options
author | Paul Phillips <paulp@improving.org> | 2009-06-30 21:39:16 +0000 |
---|---|---|
committer | Paul Phillips <paulp@improving.org> | 2009-06-30 21:39:16 +0000 |
commit | ab099645c9dfc907b800c42dc36033e9ae1a1e05 (patch) | |
tree | a56a5fdedf6043b1fe84aa37d937bd0beda5d82f | |
parent | d6519af64cab257fd45d12b818b2117e9c0f5440 (diff) | |
download | scala-ab099645c9dfc907b800c42dc36033e9ae1a1e05.tar.gz scala-ab099645c9dfc907b800c42dc36033e9ae1a1e05.tar.bz2 scala-ab099645c9dfc907b800c42dc36033e9ae1a1e05.zip |
Mostly rewriting Unapplies as I work my way thr...
Mostly rewriting Unapplies as I work my way through all the pattern
matcher related code.
3 files changed, 113 insertions, 119 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala index 3dce8148db..96da61a7fd 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -16,6 +16,7 @@ trait TreeDSL { import global._ import definitions._ + import gen.{ scalaDot } object CODE { object LIT extends (Any => Literal) { @@ -182,12 +183,25 @@ trait TreeDSL { def BLOCK(xs: Tree*) = Block(xs.init.toList, xs.last) def NOT(tree: Tree) = Select(tree, getMember(BooleanClass, nme.UNARY_!)) - // - // Unused, from the pattern matcher: - // def SEQELEM(tpe: Type): Type = (tpe.widen baseType SeqClass) match { - // case NoType => Predef.error("arg " + tpe + " not subtype of Seq[A]") - // case t => t typeArgs 0 - // } + private val _SOME = scalaDot(nme.Some) + def SOME(xs: Tree*) = Apply(_SOME, List(makeTupleTerm(xs.toList, true))) + + /** Some of this is basically verbatim from TreeBuilder, but we do not want + * to get involved with him because he's an untyped only sort. + */ + private def tupleName(count: Int, f: (String) => Name = newTermName(_: String)) = + scalaDot(f("Tuple" + count)) + + def makeTupleTerm(trees: List[Tree], flattenUnary: Boolean): Tree = trees match { + case Nil => UNIT + case List(tree) if flattenUnary => tree + case _ => Apply(tupleName(trees.length), trees) + } + def makeTupleType(trees: List[Tree], flattenUnary: Boolean): Tree = trees match { + case Nil => gen.scalaUnitConstr + case List(tree) if flattenUnary => tree + case _ => AppliedTypeTree(tupleName(trees.length, newTypeName), trees) + } /** Implicits - some of these should probably disappear **/ implicit def mkTreeMethods(target: Tree): TreeMethods = new TreeMethods(target) diff --git a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala index 3e265c9bfb..e0577a9b0c 100644 --- a/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala +++ b/src/compiler/scala/tools/nsc/matching/ParallelMatching.scala @@ -377,13 +377,13 @@ trait ParallelMatching extends ast.TreeDSL { final def tree(implicit theOwner: Symbol, failTree: Tree) = { val (uacall, vdefs, srep, frep) = this.getTransition val succ = srep.toTree - val fail = frep.map(_.toTree) getOrElse failTree + val fail = frep map (_.toTree) getOrElse (failTree) val cond = - if (uacall.symbol.tpe.isBoolean) typer.typed(ID(uacall.symbol)) + if (uacall.symbol.tpe.isBoolean) typer typed ID(uacall.symbol) else uacall.symbol IS_DEFINED typer typed (squeezedBlock( - List(rep.handleOuter(uacall)), + List(rep handleOuter uacall), IF(cond) THEN squeezedBlock(vdefs, succ) ELSE fail )) } diff --git a/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala b/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala index 2b8837c084..e499376068 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Unapplies.scala @@ -12,10 +12,18 @@ import symtab.Flags._ * @author Martin Odersky * @version 1.0 */ -trait Unapplies { self: Analyzer => +trait Unapplies extends ast.TreeDSL +{ + self: Analyzer => import global._ import definitions._ + import CODE.{ CASE => _, _ } + + private def isVarargs(vd: ValDef) = treeInfo isRepeatedParamType vd.tpt + private def isByName(vd: ValDef) = treeInfo isByNameParamType vd.tpt + private def toIdent(x: DefTree) = Ident(x.name) + private def applyAndReturn[T](f: T => Unit)(x: T): T = { f(x) ; x } /** returns type list for return type of the extraction */ def unapplyTypeList(ufn: Symbol, ufntpe: Type) = { @@ -24,7 +32,7 @@ trait Unapplies { self: Analyzer => ufn.name match { case nme.unapply => unapplyTypeListFromReturnType(ufntpe) case nme.unapplySeq => unapplyTypeListFromReturnTypeSeq(ufntpe) - case _ => throw new TypeError(ufn+" is not an unapply or unapplySeq") + case _ => throw new TypeError(ufn+" is not an unapply or unapplySeq") } } /** (the inverse of unapplyReturnTypeSeq) @@ -33,20 +41,18 @@ trait Unapplies { self: Analyzer => * - returns T0...Tn if n>0 and T <: Product[T0...Tn]] * - returns T otherwise */ - def unapplyTypeListFromReturnType(tp1: Type): List[Type] = { // rename: unapplyTypeListFromReturnType + def unapplyTypeListFromReturnType(tp1: Type): List[Type] = { val tp = unapplyUnwrap(tp1) - val B = BooleanClass - val O = OptionClass - val S = SomeClass - tp.typeSymbol match { // unapplySeqResultToMethodSig - case B => Nil - case O | S => + tp.typeSymbol match { // unapplySeqResultToMethodSig + case BooleanClass => Nil + case OptionClass | SomeClass => val prod = tp.typeArgs.head - getProductArgs(prod) match { - case Some(all @ (x1::x2::xs)) => all // n >= 2 - case _ => prod::Nil // special n == 0 || n == 1 + getProductArgs(prod) match { + case Some(xs) if xs.size > 1 => xs // n > 1 + case _ => List(prod) // special n == 0 || n == 1 } - case _ => throw new TypeError("result type "+tp+" of unapply not in {boolean, Option[_], Some[_]}") + case _ => + throw new TypeError("result type "+tp+" of unapply not in {Boolean, Option[_], Some[_]}") } } @@ -57,15 +63,16 @@ trait Unapplies { self: Analyzer => */ def unapplyTypeListFromReturnTypeSeq(tp1: Type): List[Type] = { val tp = unapplyUnwrap(tp1) - val O = OptionClass; val S = SomeClass; tp.typeSymbol match { - case O | S => - val ts = unapplyTypeListFromReturnType(tp1) - val last1 = ts.last.baseType(SeqClass) match { - case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args) - case _ => throw new TypeError("last not seq") - } - ts.init ::: List(last1) - case _ => throw new TypeError("result type "+tp+" of unapply not in {Option[_], Some[_]}") + tp.typeSymbol match { + case OptionClass | SomeClass => + val ts = unapplyTypeListFromReturnType(tp1) + val last1 = (ts.last baseType SeqClass) match { + case TypeRef(pre, seqClass, args) => typeRef(pre, RepeatedParamClass, args) // XXX seqClass or SeqClass? + case _ => throw new TypeError("last not seq") + } + ts.init ::: List(last1) + case _ => + throw new TypeError("result type "+tp+" of unapply not in {Option[_], Some[_]}") } } @@ -73,50 +80,36 @@ trait Unapplies { self: Analyzer => * for n == 0, boolean * for n == 1, Some[T0] * else Some[Product[Ti]] - def unapplyReturnType(elems: List[Type], useWildCards: Boolean) = - if (elems.isEmpty) - BooleanClass.tpe - else if (elems.length == 1) - optionType(if(useWildCards) WildcardType else elems(0)) - else - productType({val es = elems; if(useWildCards) elems map { x => WildcardType} else elems}) */ def unapplyReturnTypeExpected(argsLength: Int) = argsLength match { case 0 => BooleanClass.tpe case 1 => optionType(WildcardType) - case n => optionType(productType(List.range(0,n).map (arg => WildcardType))) + case n => optionType(productType((List fill n)(WildcardType))) } /** returns unapply or unapplySeq if available */ - def unapplyMember(tp: Type): Symbol = { - var unapp = tp.member(nme.unapply) - if (unapp == NoSymbol) unapp = tp.member(nme.unapplySeq) - unapp + def unapplyMember(tp: Type): Symbol = (tp member nme.unapply) match { + case NoSymbol => tp member nme.unapplySeq + case unapp => unapp } - def copyUntyped[T <: Tree](tree: T): T = { - val tree1 = tree.syntheticDuplicate - UnTyper.traverse(tree1) - tree1 - } + def copyUntyped[T <: Tree](tree: T): T = + applyAndReturn[T](UnTyper traverse _)(tree.syntheticDuplicate) - def copyUntypedInvariant(td: TypeDef): TypeDef = { - val tree1 = treeCopy.TypeDef(td, td.mods &~ (COVARIANT | CONTRAVARIANT), - td.name, td.tparams map (_.syntheticDuplicate), td.rhs.syntheticDuplicate) - UnTyper.traverse(tree1) - tree1 - } + def copyUntypedInvariant(td: TypeDef): TypeDef = + applyAndReturn[TypeDef](UnTyper traverse _)( + treeCopy.TypeDef(td, td.mods &~ (COVARIANT | CONTRAVARIANT), td.name, + td.tparams map (_.syntheticDuplicate), td.rhs.syntheticDuplicate) + ) private def classType(cdef: ClassDef, tparams: List[TypeDef]): Tree = { - val tycon = gen.mkAttributedRef(cdef.symbol) - if (tparams.isEmpty) tycon else AppliedTypeTree(tycon, tparams map (x => Ident(x.name))) + val tycon = REF(cdef.symbol) + if (tparams.isEmpty) tycon else AppliedTypeTree(tycon, tparams map toIdent) } private def constrParamss(cdef: ClassDef): List[List[ValDef]] = { - val constr = treeInfo.firstConstructor(cdef.impl.body) - (constr: @unchecked) match { - case DefDef(_, _, _, vparamss, _, _) => vparamss map (_ map copyUntyped[ValDef]) - } + val DefDef(_, _, _, vparamss, _, _) = treeInfo firstConstructor cdef.impl.body + vparamss map (_ map copyUntyped[ValDef]) } /** The return value of an unapply method of a case class C[Ts] @@ -124,95 +117,82 @@ trait Unapplies { self: Analyzer => * @param caseclazz The case class C[Ts] */ private def caseClassUnapplyReturnValue(param: Name, caseclazz: Symbol) = { - def caseFieldAccessorValue(selector: Symbol) = Select(Ident(param), selector) - val accessors = caseclazz.caseFieldAccessors - if (accessors.isEmpty) Literal(true) - else - Apply( - gen.scalaDot(nme.Some), - List( - if (accessors.tail.isEmpty) caseFieldAccessorValue(accessors.head) - else Apply( - gen.scalaDot(newTermName("Tuple" + accessors.length)), - accessors map caseFieldAccessorValue))) + def caseFieldAccessorValue(selector: Symbol): Tree = Ident(param) DOT selector + + caseclazz.caseFieldAccessors match { + case Nil => TRUE + case xs => SOME(xs map caseFieldAccessorValue: _*) + } } /** The module corresponding to a case class; without any member definitions */ - def caseModuleDef(cdef: ClassDef): ModuleDef = - companionModuleDef( - cdef, - if (!(cdef.mods hasFlag ABSTRACT) && cdef.tparams.isEmpty && constrParamss(cdef).length == 1) - List(gen.scalaFunctionConstr(constrParamss(cdef).head map (_.tpt), Ident(cdef.name)), - gen.scalaScalaObjectConstr) - else - List(gen.scalaScalaObjectConstr)) + def caseModuleDef(cdef: ClassDef): ModuleDef = { + def inheritFromFun1 = !(cdef.mods hasFlag ABSTRACT) && cdef.tparams.isEmpty && constrParamss(cdef).length == 1 + def createFun1 = gen.scalaFunctionConstr(constrParamss(cdef).head map (_.tpt), toIdent(cdef)) + def parents = if (inheritFromFun1) List(createFun1) else Nil + + companionModuleDef(cdef, parents ::: List(gen.scalaScalaObjectConstr)) + } def companionModuleDef(cdef: ClassDef, parents: List[Tree]): ModuleDef = atPos(cdef.pos) { ModuleDef( Modifiers(cdef.mods.flags & AccessFlags | SYNTHETIC, cdef.mods.privateWithin), cdef.name.toTermName, - Template(parents, emptyValDef, Modifiers(0), List(), List(List()), List())) + Template(parents, emptyValDef, NoMods, Nil, List(Nil), Nil)) } + private val caseMods = Modifiers(SYNTHETIC | CASE) + /** The apply method corresponding to a case class */ def caseModuleApplyMeth(cdef: ClassDef): DefDef = { - val tparams = cdef.tparams map copyUntypedInvariant - val cparamss = constrParamss(cdef) - atPos(cdef.pos) { - DefDef( - Modifiers(SYNTHETIC | CASE), - nme.apply, - tparams, - cparamss, - classType(cdef, tparams), + val tparams = cdef.tparams map copyUntypedInvariant + val cparamss = constrParamss(cdef) + atPos(cdef.pos)( + DefDef(caseMods, nme.apply, tparams, cparamss, classType(cdef, tparams), New(classType(cdef, tparams), cparamss map (_ map gen.paramToArg))) - } + ) } /** The unapply method corresponding to a case class */ def caseModuleUnapplyMeth(cdef: ClassDef): DefDef = { - val tparams = cdef.tparams map copyUntypedInvariant - val unapplyParamName = newTermName("x$0") - val hasVarArg = constrParamss(cdef) match { - case (cps @ (_ :: _)) :: _ => treeInfo.isRepeatedParamType(cps.last.tpt) - case _ => false - } - atPos(cdef.pos) { - DefDef( - Modifiers(SYNTHETIC | CASE), - if (hasVarArg) nme.unapplySeq else nme.unapply, - tparams, - List(List(ValDef(Modifiers(PARAM | SYNTHETIC), unapplyParamName, - classType(cdef, tparams), EmptyTree))), - TypeTree(), - caseClassUnapplyReturnValue(unapplyParamName, cdef.symbol)) + val tparams = cdef.tparams map copyUntypedInvariant + val paramName = newTermName("x$0") + val method = constrParamss(cdef) match { + case xs :: _ if !xs.isEmpty && isVarargs(xs.last) => nme.unapplySeq + case _ => nme.unapply } + val cparams = List(ValDef(Modifiers(PARAM | SYNTHETIC), paramName, classType(cdef, tparams), EmptyTree)) + + atPos(cdef.pos)( + DefDef(caseMods, method, tparams, List(cparams), TypeTree(), + caseClassUnapplyReturnValue(paramName, cdef.symbol)) + ) } def caseClassCopyMeth(cdef: ClassDef): Option[DefDef] = { - val cparamss = constrParamss(cdef) - if (cparamss.length == 1 && cparamss.head.isEmpty || // no copy method if there are no arguments - cdef.symbol.hasFlag(ABSTRACT) || - cparamss.exists(_.exists(vd => - treeInfo.isRepeatedParamType(vd.tpt) || - treeInfo.isByNameParamType(vd.tpt)))) - None + def isDisallowed(vd: ValDef) = isVarargs(vd) || isByName(vd) + val cparamss = constrParamss(cdef) + val flat = cparamss flatten + + if (flat.isEmpty || (cdef.symbol hasFlag ABSTRACT) || (flat exists isDisallowed)) None else { val tparams = cdef.tparams map copyUntypedInvariant // the parameter types have to be exactly the same as the constructor's parameter types; so it's // not good enough to just duplicated the (untyped) tpt tree; the parameter types are removed here // and re-added in ``finishWith'' in the namer. - val paramss = cparamss map (_.map(vd => - treeCopy.ValDef(vd, vd.mods | DEFAULTPARAM, vd.name, - TypeTree().setOriginal(vd.tpt), Ident(vd.name)))) - val classTpe = classType(cdef, tparams) - Some(atPos(cdef.pos) { + def paramWithDefault(vd: ValDef) = + treeCopy.ValDef(vd, vd.mods | DEFAULTPARAM, vd.name, TypeTree() setOriginal vd.tpt, toIdent(vd)) + + val paramss = cparamss map (_ map paramWithDefault) + val classTpe = classType(cdef, tparams) + + Some(atPos(cdef.pos)( DefDef(Modifiers(SYNTHETIC), nme.copy, tparams, paramss, classTpe, - New(classTpe, paramss map (_ map (p => Ident(p.name))))) - }) + New(classTpe, paramss map (_ map toIdent))) + )) } } } |