diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-06-29 06:39:28 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-06-29 06:39:28 +0800 |
commit | 8a977b065418f07d2bf4fe1607a5534c32d04c47 (patch) | |
tree | c1a233fa275dbfab66afec7b1728e98eb7b2af28 /sql/catalyst | |
parent | 25520e976275e0d1e3bf9c73128ef4dec4618568 (diff) | |
download | spark-8a977b065418f07d2bf4fe1607a5534c32d04c47.tar.gz spark-8a977b065418f07d2bf4fe1607a5534c32d04c47.tar.bz2 spark-8a977b065418f07d2bf4fe1607a5534c32d04c47.zip |
[SPARK-16100][SQL] fix bug when use Map as the buffer type of Aggregator
## What changes were proposed in this pull request?
The root cause is in `MapObjects`. Its parameter `loopVar` is not declared as child, but sometimes can be same with `lambdaFunction`(e.g. the function that takes `loopVar` and produces `lambdaFunction` may be `identity`), which is a child. This brings trouble when call `withNewChildren`, it may mistakenly treat `loopVar` as a child and cause `IndexOutOfBoundsException: 0` later.
This PR fixes this bug by simply pulling out the paremters from `LambdaVariable` and pass them to `MapObjects` directly.
## How was this patch tested?
new test in `DatasetAggregatorSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13835 from cloud-fan/map-objects.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c597a2a709..ea4dee174e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -353,7 +353,7 @@ object MapObjects { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopVar, function(loopVar), inputData) + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) } } @@ -365,14 +365,20 @@ object MapObjects { * The following collection ObjectTypes are currently supported: * Seq, Array, ArrayData, java.util.List * - * @param loopVar A place holder that used as the loop variable when iterate the collection, and - * used as input for the `lambdaFunction`. It also carries the element type info. + * @param loopValue the name of the loop variable that used when iterate the collection, and used + * as input for the `lambdaFunction` + * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and + * used as input for the `lambdaFunction` + * @param loopVarDataType the data type of the loop variable that used when iterate the collection, + * and used as input for the `lambdaFunction` * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. */ case class MapObjects private( - loopVar: LambdaVariable, + loopValue: String, + loopIsNull: String, + loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression) extends Expression with NonSQLExpression { @@ -386,9 +392,9 @@ case class MapObjects private( override def dataType: DataType = ArrayType(lambdaFunction.dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVar.dataType) - ctx.addMutableState("boolean", loopVar.isNull, "") - ctx.addMutableState(elementJavaType, loopVar.value, "") + val elementJavaType = ctx.javaType(loopVarDataType) + ctx.addMutableState("boolean", loopIsNull, "") + ctx.addMutableState(elementJavaType, loopValue, "") val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -443,11 +449,11 @@ case class MapObjects private( } val loopNullCheck = inputData.dataType match { - case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" // The element of primitive array will never be null. case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => - s"${loopVar.isNull} = false" - case _ => s"${loopVar.isNull} = ${loopVar.value} == null;" + s"$loopIsNull = false" + case _ => s"$loopIsNull = $loopValue == null;" } val code = s""" @@ -462,7 +468,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - ${loopVar.value} = ($elementJavaType) ($getLoopVar); + $loopValue = ($elementJavaType) ($getLoopVar); $loopNullCheck ${genFunction.code} |