diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-07-19 22:42:44 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-19 22:42:44 -0700 |
commit | 930253e0766a7585347edfb73ed11b1bf78143fe (patch) | |
tree | 4864c0b4ca9cf9ece0aed638ce2d3e0e8ceaa330 | |
parent | d743bec645fd2a65bd488d2d660b3aa2135b4da6 (diff) | |
download | spark-930253e0766a7585347edfb73ed11b1bf78143fe.tar.gz spark-930253e0766a7585347edfb73ed11b1bf78143fe.tar.bz2 spark-930253e0766a7585347edfb73ed11b1bf78143fe.zip |
[SPARK-9185][SQL] improve code gen for mutable states to support complex initialization
Sometimes we need more than one step to initialize the mutable states in code gen, as in https://github.com/apache/spark/pull/7516
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7521 from cloud-fan/init and squashes the following commits:
2106445 [Wenchen Fan] improve code gen for mutable states
9 files changed, 42 insertions, 38 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7c388bc346..b2468b6a18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -60,13 +60,19 @@ class CodeGenContext { /** * Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a * 3-tuple: java type, variable name, code to init it. + * As an example, ("int", "count", "count = 0;") will produce code: + * {{{ + * private int count; + * count = 0; + * }}} + * * They will be kept as member variables in generated classes like `SpecificProjection`. */ val mutableStates: mutable.ArrayBuffer[(String, String, String)] = mutable.ArrayBuffer.empty[(String, String, String)] - def addMutableState(javaType: String, variableName: String, initialValue: String): Unit = { - mutableStates += ((javaType, variableName, initialValue)) + def addMutableState(javaType: String, variableName: String, initialCode: String): Unit = { + mutableStates += ((javaType, variableName, initialCode)) } final val intervalType: String = classOf[Interval].getName @@ -234,6 +240,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName + protected def declareMutableStates(ctx: CodeGenContext) = { + ctx.mutableStates.map { case (javaType, variableName, _) => + s"private $javaType $variableName;" + }.mkString("\n ") + } + + protected def initMutableStates(ctx: CodeGenContext) = { + ctx.mutableStates.map(_._3).mkString("\n ") + } + /** * Generates a class for a given input expression. 
Called when there is not cached code * already available. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b82bd6814b..03b4b3c216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -78,10 +78,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } } - val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => - s"private $javaType $variableName = $initialValue;" - }.mkString("\n ") - val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); @@ -89,13 +85,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { - private $exprType[] expressions = null; - private $mutableRowType mutableRow = null; - $mutableStates + private $exprType[] expressions; + private $mutableRowType mutableRow; + ${declareMutableStates(ctx)} public SpecificProjection($exprType[] expr) { expressions = expr; mutableRow = new $genericMutableRowType(${expressions.size}); + ${initMutableStates(ctx)} } public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 856ff9f1f9..2e6f9e204d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -84,9 +84,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } """ }.mkString("\n") - val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => - s"private $javaType $variableName = $initialValue;" - }.mkString("\n ") val code = s""" public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); @@ -94,11 +91,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR class SpecificOrdering extends ${classOf[BaseOrdering].getName} { - private $exprType[] expressions = null; - $mutableStates + private $exprType[] expressions; + ${declareMutableStates(ctx)} public SpecificOrdering($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 9e5a745d51..1dda5992c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,9 +40,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.gen(ctx) - val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => - s"private $javaType $variableName = $initialValue;" - }.mkString("\n ") val code = s""" public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); @@ -50,9 +47,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool class SpecificPredicate 
extends ${classOf[Predicate].getName} { private final $exprType[] expressions; - $mutableStates + ${declareMutableStates(ctx)} public SpecificPredicate($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 8f9fcbf810..405d6b0e3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -151,21 +151,18 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") - val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => - s"private $javaType $variableName = $initialValue;" - }.mkString("\n ") - val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); } class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions = null; - $mutableStates + private $exprType[] expressions; + ${declareMutableStates(ctx)} public SpecificProjection($exprType[] expr) { expressions = expr; + ${initMutableStates(ctx)} } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index a81d545a8e..3a8e8302b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -74,10 +74,6 @@ object GenerateUnsafeProjection 
extends CodeGenerator[Seq[Expression], UnsafePro }""" }.mkString("\n ") - val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => - s"private $javaType $variableName = $initialValue;" - }.mkString("\n ") - val code = s""" private $exprType[] expressions; @@ -90,10 +86,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private UnsafeRow target = new UnsafeRow(); private byte[] buffer = new byte[64]; + ${declareMutableStates(ctx)} - $mutableStates - - public SpecificProjection() {} + public SpecificProjection() { + ${initMutableStates(ctx)} + } // Scala.Function1 need this public Object apply(Object row) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 65093dc722..822898e561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -60,9 +60,9 @@ case class Rand(seed: Long) extends RDG(seed) { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val rngTerm = ctx.freshName("rng") - val className = classOf[XORShiftRandom].getCanonicalName + val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); @@ -83,9 +83,9 @@ case class Randn(seed: Long) extends RDG(seed) { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val rngTerm = ctx.freshName("rng") - val className = classOf[XORShiftRandom].getCanonicalName + val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, 
rngTerm, - s"new $className($seed + org.apache.spark.TaskContext.getPartitionId())") + s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index fec403fe2d..4d8ed08973 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -58,9 +58,9 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") - ctx.addMutableState(ctx.JAVA_LONG, countTerm, "0L") + ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, - "((long) org.apache.spark.TaskContext.getPartitionId()) << 33") + s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") ev.isNull = "false" s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 7c790c549a..43ffc9cc84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -41,7 +41,8 @@ private[sql] case object SparkPartitionID extends LeafExpression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") - 
ctx.addMutableState(ctx.JAVA_INT, idTerm, "org.apache.spark.TaskContext.getPartitionId()") + ctx.addMutableState(ctx.JAVA_INT, idTerm, + s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.primitive} = $idTerm;" } |