summaryrefslogtreecommitdiff
path: root/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala')
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala60
1 files changed, 40 insertions, 20 deletions
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
index 3c35149636..15adfa7d82 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
@@ -84,6 +84,9 @@ abstract class SelectiveCPSTransform extends PluginComponent with
log("found shift: " + tree)
atPos(tree.pos) {
val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct?
+ //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage),
+ //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR)
+ //gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct?
log(funR.tpe)
Apply(
TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
@@ -183,21 +186,32 @@ abstract class SelectiveCPSTransform extends PluginComponent with
case expr => (Nil, expr)
}
- val expr2 = if (catches.nonEmpty) {
- val pos = catches.head.pos
- val arg = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe)
- val catches2 = catches1 map (duplicateTree(_).asInstanceOf[CaseDef])
- val rhs = Match(Ident(arg), catches2)
- val fun = Function(List(ValDef(arg)), rhs)
- val expr2 = localTyper.typed(atPos(pos) { Apply(Select(expr1, expr1.tpe.member("flatMapCatch")), List(fun)) })
+ val targettp = transformCPSType(tree.tpe)
- arg.owner = fun.symbol
+// val expr2 = if (catches.nonEmpty) {
+ val pos = catches.head.pos
+ val argSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe)
+ val rhs = Match(Ident(argSym), catches1)
+ val fun = Function(List(ValDef(argSym)), rhs)
+ val funSym = currentOwner.newValueParameter(pos, "$catches").setInfo(appliedType(PartialFunctionClass.tpe, List(ThrowableClass.tpe, targettp)))
+ val funDef = localTyper.typed(atPos(pos) { ValDef(funSym, fun) })
+ val expr2 = localTyper.typed(atPos(pos) { Apply(Select(expr1, expr1.tpe.member("flatMapCatch")), List(Ident(funSym))) })
+
+ argSym.owner = fun.symbol
val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
chown.traverse(rhs)
- expr2
- } else
- expr1
+ val exSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe)
+ val catch2 = { localTyper.typedCases(tree, List(
+ CaseDef(Bind(exSym, Typed(Ident("_"), TypeTree(ThrowableClass.tpe))),
+ Apply(Select(Ident(funSym), "isDefinedAt"), List(Ident(exSym))),
+ Apply(Ident(funSym), List(Ident(exSym))))
+ ), ThrowableClass.tpe, targettp) }
+
+ //typedCases(tree, catches, ThrowableClass.tpe, pt)
+
+ localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1)))
+
/*
disabled for now - see notes above
@@ -215,9 +229,6 @@ abstract class SelectiveCPSTransform extends PluginComponent with
} else
expr2
*/
-
-
- treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catches1, finalizer1)
} else {
treeCopy.Try(tree, block1, catches1, finalizer1)
}
@@ -249,7 +260,9 @@ abstract class SelectiveCPSTransform extends PluginComponent with
log("found marked ValDef "+name+" of type " + vd.symbol.tpe)
val tpe = vd.symbol.tpe
- val rhs1 = transform(rhs)
+ val rhs1 = atOwner(vd.symbol) { transform(rhs) }
+
+ new ChangeOwnerTraverser(vd.symbol, currentOwner).traverse(rhs1) // TODO: don't traverse twice
log("valdef symbol " + vd.symbol + " has type " + tpe)
log("right hand side " + rhs1 + " has type " + rhs1.tpe)
@@ -257,8 +270,10 @@ abstract class SelectiveCPSTransform extends PluginComponent with
log("currentOwner: " + currentOwner)
log("currentMethod: " + currentMethod)
-
val (bodyStms, bodyExpr) = transBlock(rest, expr)
+ // FIXME: result will later be traversed again by TreeSymSubstituter and
+ // ChangeOwnerTraverser => exp. running time.
+ // Should be changed to fuse traversals into one.
val specialCaseTrivial = bodyExpr match {
case Apply(fun, args) =>
@@ -291,6 +306,8 @@ abstract class SelectiveCPSTransform extends PluginComponent with
arg.owner = fun.symbol
new ChangeOwnerTraverser(currentOwner, fun.symbol).traverse(body)
+ // see note about multiple traversals above
+
log("fun.symbol: "+fun.symbol)
log("fun.symbol.owner: "+fun.symbol.owner)
log("arg.owner: "+arg.owner)
@@ -314,6 +331,8 @@ abstract class SelectiveCPSTransform extends PluginComponent with
})
}
+ def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr
+
try {
if (specialCaseTrivial) {
log("will optimize possible tail call: " + bodyExpr)
@@ -332,9 +351,9 @@ abstract class SelectiveCPSTransform extends PluginComponent with
val argSym = currentOwner.newValue(vd.symbol.name).setInfo(tpe)
val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member("getTrivialValue"))))
val switchExpr = localTyper.typed(atPos(vd.symbol.pos) {
- val body2 = duplicateTree(Block(bodyStms, bodyExpr)) // dup before typing!
+ val body2 = duplicateTree(mkBlock(bodyStms, bodyExpr)) // dup before typing!
If(Select(ctxRef, ctxSym.tpe.member("isTrivial")),
- applyTrivial(argSym, Block(argDef::bodyStms, bodyExpr)),
+ applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)),
applyCombinatorFun(ctxRef, body2))
})
(List(ctxDef), switchExpr)
@@ -342,7 +361,7 @@ abstract class SelectiveCPSTransform extends PluginComponent with
// ctx.flatMap { <lhs> => ... }
// or
// ctx.map { <lhs> => ... }
- (Nil, applyCombinatorFun(rhs1, Block(bodyStms, bodyExpr)))
+ (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr)))
}
} catch {
case ex:TypeError =>
@@ -351,8 +370,9 @@ abstract class SelectiveCPSTransform extends PluginComponent with
}
case _ =>
+ val stm1 = transform(stm)
val (a, b) = transBlock(rest, expr)
- (transform(stm)::a, b)
+ (stm1::a, b)
}
}
}