aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-07-19 22:42:44 -0700
committerReynold Xin <rxin@databricks.com>2015-07-19 22:42:44 -0700
commit930253e0766a7585347edfb73ed11b1bf78143fe (patch)
tree4864c0b4ca9cf9ece0aed638ce2d3e0e8ceaa330
parentd743bec645fd2a65bd488d2d660b3aa2135b4da6 (diff)
downloadspark-930253e0766a7585347edfb73ed11b1bf78143fe.tar.gz
spark-930253e0766a7585347edfb73ed11b1bf78143fe.tar.bz2
spark-930253e0766a7585347edfb73ed11b1bf78143fe.zip
[SPARK-9185][SQL] improve code gen for mutable states to support complex initialization
Sometimes we need more than one step to initialize the mutable states in code gen like https://github.com/apache/spark/pull/7516 Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7521 from cloud-fan/init and squashes the following commits: 2106445 [Wenchen Fan] improve code gen for mutable states
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala20
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala3
9 files changed, 42 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 7c388bc346..b2468b6a18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -60,13 +60,19 @@ class CodeGenContext {
/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
+ * As an example, ("int", "count", "count = 0;") will produce code:
+ * {{{
+ * private int count;
+ * count = 0;
+ * }}}
+ *
* They will be kept as member variables in generated classes like `SpecificProjection`.
*/
val mutableStates: mutable.ArrayBuffer[(String, String, String)] =
mutable.ArrayBuffer.empty[(String, String, String)]
- def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = {
- mutableStates += ((javaType, variableName, initialValue))
+ def addMutableState(javaType: String, variableName: String, initialCode: String): Unit = {
+ mutableStates += ((javaType, variableName, initialCode))
}
final val intervalType: String = classOf[Interval].getName
@@ -234,6 +240,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected val mutableRowType: String = classOf[MutableRow].getName
protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
+ protected def declareMutableStates(ctx: CodeGenContext) = {
+ ctx.mutableStates.map { case (javaType, variableName, _) =>
+ s"private $javaType $variableName;"
+ }.mkString("\n ")
+ }
+
+ protected def initMutableStates(ctx: CodeGenContext) = {
+ ctx.mutableStates.map(_._3).mkString("\n ")
+ }
+
/**
* Generates a class for a given input expression. Called when there is not cached code
* already available.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index b82bd6814b..03b4b3c216 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -78,10 +78,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
}
}
- val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
- s"private $javaType $variableName = $initialValue;"
- }.mkString("\n ")
-
val code = s"""
public Object generate($exprType[] expr) {
return new SpecificProjection(expr);
@@ -89,13 +85,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
class SpecificProjection extends ${classOf[BaseMutableProjection].getName} {
- private $exprType[] expressions = null;
- private $mutableRowType mutableRow = null;
- $mutableStates
+ private $exprType[] expressions;
+ private $mutableRowType mutableRow;
+ ${declareMutableStates(ctx)}
public SpecificProjection($exprType[] expr) {
expressions = expr;
mutableRow = new $genericMutableRowType(${expressions.size});
+ ${initMutableStates(ctx)}
}
public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 856ff9f1f9..2e6f9e204d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -84,9 +84,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
"""
}.mkString("\n")
- val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
- s"private $javaType $variableName = $initialValue;"
- }.mkString("\n ")
val code = s"""
public SpecificOrdering generate($exprType[] expr) {
return new SpecificOrdering(expr);
@@ -94,11 +91,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
class SpecificOrdering extends ${classOf[BaseOrdering].getName} {
- private $exprType[] expressions = null;
- $mutableStates
+ private $exprType[] expressions;
+ ${declareMutableStates(ctx)}
public SpecificOrdering($exprType[] expr) {
expressions = expr;
+ ${initMutableStates(ctx)}
}
@Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 9e5a745d51..1dda5992c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -40,9 +40,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
val ctx = newCodeGenContext()
val eval = predicate.gen(ctx)
- val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
- s"private $javaType $variableName = $initialValue;"
- }.mkString("\n ")
val code = s"""
public SpecificPredicate generate($exprType[] expr) {
return new SpecificPredicate(expr);
@@ -50,9 +47,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
class SpecificPredicate extends ${classOf[Predicate].getName} {
private final $exprType[] expressions;
- $mutableStates
+ ${declareMutableStates(ctx)}
public SpecificPredicate($exprType[] expr) {
expressions = expr;
+ ${initMutableStates(ctx)}
}
@Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 8f9fcbf810..405d6b0e3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -151,21 +151,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
s"""if (!nullBits[$i]) arr[$i] = c$i;"""
}.mkString("\n ")
- val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
- s"private $javaType $variableName = $initialValue;"
- }.mkString("\n ")
-
val code = s"""
public SpecificProjection generate($exprType[] expr) {
return new SpecificProjection(expr);
}
class SpecificProjection extends ${classOf[BaseProjection].getName} {
- private $exprType[] expressions = null;
- $mutableStates
+ private $exprType[] expressions;
+ ${declareMutableStates(ctx)}
public SpecificProjection($exprType[] expr) {
expressions = expr;
+ ${initMutableStates(ctx)}
}
@Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index a81d545a8e..3a8e8302b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -74,10 +74,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
}"""
}.mkString("\n ")
- val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) =>
- s"private $javaType $variableName = $initialValue;"
- }.mkString("\n ")
-
val code = s"""
private $exprType[] expressions;
@@ -90,10 +86,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private UnsafeRow target = new UnsafeRow();
private byte[] buffer = new byte[64];
+ ${declareMutableStates(ctx)}
- $mutableStates
-
- public SpecificProjection() {}
+ public SpecificProjection() {
+ ${initMutableStates(ctx)}
+ }
// Scala.Function1 need this
public Object apply(Object row) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index 65093dc722..822898e561 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -60,9 +60,9 @@ case class Rand(seed: Long) extends RDG(seed) {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val rngTerm = ctx.freshName("rng")
- val className = classOf[XORShiftRandom].getCanonicalName
+ val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
- s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())")
+ s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());")
ev.isNull = "false"
s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble();
@@ -83,9 +83,9 @@ case class Randn(seed: Long) extends RDG(seed) {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val rngTerm = ctx.freshName("rng")
- val className = classOf[XORShiftRandom].getCanonicalName
+ val className = classOf[XORShiftRandom].getName
ctx.addMutableState(className, rngTerm,
- s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())")
+ s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());")
ev.isNull = "false"
s"""
final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index fec403fe2d..4d8ed08973 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -58,9 +58,9 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val countTerm = ctx.freshName("count")
val partitionMaskTerm = ctx.freshName("partitionMask")
- ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L")
+ ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
- "((long) org.apache.spark.TaskContext.getPartitionId()) << 33")
+ s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
ev.isNull = "false"
s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 7c790c549a..43ffc9cc84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -41,7 +41,8 @@ private[sql] case object SparkPartitionID extends LeafExpression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val idTerm = ctx.freshName("partitionId")
- ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()")
+ ctx.addMutableState(ctx.JAVA_INT, idTerm,
+ s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
ev.isNull = "false"
s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;"
}