From d8ec081c911a040f3fb523a68025928ae4afc906 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Dec 2015 15:11:13 +0800 Subject: [SPARK-12252][SPARK-12131][SQL] refactor MapObjects to make it less hacky In https://github.com/apache/spark/pull/10133 we found that we should ensure the children of `TreeNode` are all accessible in the `productIterator`, or the behavior will be very confusing. In this PR, I try to fix this problem by exposing the `loopVar`. This also fixes SPARK-12131 which is caused by the hacky `MapObjects`. Author: Wenchen Fan Closes #10239 from cloud-fan/map-objects. --- .../spark/sql/catalyst/ScalaReflection.scala | 4 -- .../spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- .../spark/sql/catalyst/expressions/objects.scala | 75 ++++++++++------------ .../catalyst/encoders/ExpressionEncoderSuite.scala | 1 + 4 files changed, 35 insertions(+), 47 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9b6b5b8bd1..9013fd050b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -414,10 +414,6 @@ object ScalaReflection extends ScalaReflection { } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here - // to trigger the type check. 
- extractorFor(inputObject, elementType, newPath) - MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67518f52d4..d34ec9408a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -193,7 +193,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor, input, et), + MapObjects(constructorFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e6ab9a31be..b2facfda24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -326,19 +326,28 @@ case class WrapOption(child: Expression) * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. 
*/ -case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { +case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression + with Unevaluable { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException("Only calling gen() is supported.") + override def nullable: Boolean = true - override def children: Seq[Expression] = Nil - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = { GeneratedExpressionCode(code = "", value = value, isNull = isNull) + } +} - override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") +object MapObjects { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType): MapObjects = { + val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + MapObjects(loopVar, function(loopVar), inputData) + } } /** @@ -349,20 +358,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * The following collection ObjectTypes are currently supported: * Seq, Array, ArrayData, java.util.List * - * @param function A function that returns an expression, given an attribute that can be used - * to access the current value. This is does as a lambda function so that - * a unique attribute reference can be provided for each expression (thus allowing - * us to nest multiple MapObject calls). + * @param loopVar A place holder that used as the loop variable when iterate the collection, and + * used as input for the `lambdaFunction`. 
It also carries the element type info. + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. * @param inputData An expression that when evaluted returns a collection object. - * @param elementType The type of element in the collection, expressed as a DataType. */ case class MapObjects( - function: AttributeReference => Expression, - inputData: Expression, - elementType: DataType) extends Expression { - - private lazy val loopAttribute = AttributeReference("loopVar", elementType)() - private lazy val completeFunction = function(loopAttribute) + loopVar: LambdaVariable, + lambdaFunction: Expression, + inputData: Expression) extends Expression { private def itemAccessorMethod(dataType: DataType): String => String = dataType match { case NullType => @@ -402,37 +407,23 @@ case class MapObjects( override def nullable: Boolean = true - override def children: Seq[Expression] = completeFunction :: inputData :: Nil + override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(completeFunction.dataType) + override def dataType: DataType = ArrayType(lambdaFunction.dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val elementJavaType = ctx.javaType(elementType) + val elementJavaType = ctx.javaType(loopVar.dataType) val genInputData = inputData.gen(ctx) - - // Variables to hold the element that is currently being processed. 
- val loopValue = ctx.freshName("loopValue") - val loopIsNull = ctx.freshName("loopIsNull") - - val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) - val substitutedFunction = completeFunction transform { - case a: AttributeReference if a == loopAttribute => loopVariable - } - // A hack to run this through the analyzer (to bind extractions). - val boundFunction = - SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil))) - .expressions.head.children.head - - val genFunction = boundFunction.gen(ctx) + val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(boundFunction.dataType) + val convertedType = ctx.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -446,9 +437,9 @@ case class MapObjects( } val loopNullCheck = if (primitiveElement) { - s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;" + s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -464,11 +455,11 @@ case class MapObjects( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType $loopValue = + $elementJavaType ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - if ($loopIsNull) { + if (${loopVar.isNull}) { $convertedArray[$loopIndex] = null; } else { ${genFunction.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index d6ca138672..7233e0f1b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -145,6 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { case class InnerClass(i: Int) productTest(InnerClass(1)) + encodeDecodeTest(Array(InnerClass(1)), "array of inner class") productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) -- cgit v1.2.3