aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-06-29 06:39:28 +0800
committerCheng Lian <lian@databricks.com>2016-06-29 06:39:28 +0800
commit8a977b065418f07d2bf4fe1607a5534c32d04c47 (patch)
treec1a233fa275dbfab66afec7b1728e98eb7b2af28 /sql/catalyst/src/main
parent25520e976275e0d1e3bf9c73128ef4dec4618568 (diff)
downloadspark-8a977b065418f07d2bf4fe1607a5534c32d04c47.tar.gz
spark-8a977b065418f07d2bf4fe1607a5534c32d04c47.tar.bz2
spark-8a977b065418f07d2bf4fe1607a5534c32d04c47.zip
[SPARK-16100][SQL] fix bug when use Map as the buffer type of Aggregator
## What changes were proposed in this pull request? The root cause is in `MapObjects`. Its parameter `loopVar` is not declared as a child, but it can sometimes be the same as `lambdaFunction` (e.g. the function that takes `loopVar` and produces `lambdaFunction` may be `identity`), which is a child. This causes trouble when `withNewChildren` is called: it may mistakenly treat `loopVar` as a child and cause `IndexOutOfBoundsException: 0` later. This PR fixes the bug by simply pulling the parameters out of `LambdaVariable` and passing them to `MapObjects` directly. ## How was this patch tested? New test in `DatasetAggregatorSuite`. Author: Wenchen Fan <wenchen@databricks.com> Closes #13835 from cloud-fan/map-objects.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala28
1 files changed, 17 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c597a2a709..ea4dee174e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -353,7 +353,7 @@ object MapObjects {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
- MapObjects(loopVar, function(loopVar), inputData)
+ MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
}
}
@@ -365,14 +365,20 @@ object MapObjects {
* The following collection ObjectTypes are currently supported:
* Seq, Array, ArrayData, java.util.List
*
- * @param loopVar A place holder that used as the loop variable when iterate the collection, and
- * used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param loopValue the name of the loop variable that is used when iterating over the
+ * collection, and which is used as input for the `lambdaFunction`
+ * @param loopIsNull the name of the variable holding the nullity of the loop variable, which is
+ * set when iterating over the collection and used as input for the `lambdaFunction`
+ * @param loopVarDataType the data type of the loop variable that is used when iterating over the
+ * collection, and which is used as input for the `lambdaFunction`
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
* to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
*/
case class MapObjects private(
- loopVar: LambdaVariable,
+ loopValue: String,
+ loopIsNull: String,
+ loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {
@@ -386,9 +392,9 @@ case class MapObjects private(
override def dataType: DataType = ArrayType(lambdaFunction.dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val elementJavaType = ctx.javaType(loopVar.dataType)
- ctx.addMutableState("boolean", loopVar.isNull, "")
- ctx.addMutableState(elementJavaType, loopVar.value, "")
+ val elementJavaType = ctx.javaType(loopVarDataType)
+ ctx.addMutableState("boolean", loopIsNull, "")
+ ctx.addMutableState(elementJavaType, loopValue, "")
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
@@ -443,11 +449,11 @@ case class MapObjects private(
}
val loopNullCheck = inputData.dataType match {
- case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+ case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
- s"${loopVar.isNull} = false"
- case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
+ s"$loopIsNull = false"
+ case _ => s"$loopIsNull = $loopValue == null;"
}
val code = s"""
@@ -462,7 +468,7 @@ case class MapObjects private(
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- ${loopVar.value} = ($elementJavaType) ($getLoopVar);
+ $loopValue = ($elementJavaType) ($getLoopVar);
$loopNullCheck
${genFunction.code}