aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-12-10 15:11:13 +0800
committerCheng Lian <lian@databricks.com>2015-12-10 15:11:13 +0800
commitd8ec081c911a040f3fb523a68025928ae4afc906 (patch)
tree2f6c19a26369344f2a31285106a75a9fde9182f1 /sql/catalyst
parentbd2cd4f53d1ca10f4896bd39b0e180d4929867a2 (diff)
downloadspark-d8ec081c911a040f3fb523a68025928ae4afc906.tar.gz
spark-d8ec081c911a040f3fb523a68025928ae4afc906.tar.bz2
spark-d8ec081c911a040f3fb523a68025928ae4afc906.zip
[SPARK-12252][SPARK-12131][SQL] refactor MapObjects to make it less hacky
in https://github.com/apache/spark/pull/10133 we found that, we should ensure the children of `TreeNode` are all accessible in the `productIterator`, or the behavior will be very confusing. In this PR, I try to fix this problem by exposing the `loopVar`. This also fixes SPARK-12131 which is caused by the hacky `MapObjects`. Author: Wenchen Fan <wenchen@databricks.com> Closes #10239 from cloud-fan/map-objects.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala75
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala1
4 files changed, 35 insertions, 47 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9b6b5b8bd1..9013fd050b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -414,10 +414,6 @@ object ScalaReflection extends ScalaReflection {
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here
- // to trigger the type check.
- extractorFor(inputObject, elementType, newPath)
-
MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67518f52d4..d34ec9408a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -193,7 +193,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor, input, et),
+ MapObjects(constructorFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e6ab9a31be..b2facfda24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -326,19 +326,28 @@ case class WrapOption(child: Expression)
* A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
* manually, but will instead be passed into the provided lambda function.
*/
-case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
+case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression
+ with Unevaluable {
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
- throw new UnsupportedOperationException("Only calling gen() is supported.")
+ override def nullable: Boolean = true
- override def children: Seq[Expression] = Nil
- override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
GeneratedExpressionCode(code = "", value = value, isNull = isNull)
+ }
+}
- override def nullable: Boolean = false
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+object MapObjects {
+ private val curId = new java.util.concurrent.atomic.AtomicInteger()
+ def apply(
+ function: Expression => Expression,
+ inputData: Expression,
+ elementType: DataType): MapObjects = {
+ val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
+ val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+ val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+ MapObjects(loopVar, function(loopVar), inputData)
+ }
}
/**
@@ -349,20 +358,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
* The following collection ObjectTypes are currently supported:
* Seq, Array, ArrayData, java.util.List
*
- * @param function A function that returns an expression, given an attribute that can be used
- * to access the current value. This is does as a lambda function so that
- * a unique attribute reference can be provided for each expression (thus allowing
- * us to nest multiple MapObject calls).
+ * @param loopVar A place holder that used as the loop variable when iterate the collection, and
+ * used as input for the `lambdaFunction`. It also carries the element type info.
+ * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
+ * to handle collection elements.
* @param inputData An expression that when evaluted returns a collection object.
- * @param elementType The type of element in the collection, expressed as a DataType.
*/
case class MapObjects(
- function: AttributeReference => Expression,
- inputData: Expression,
- elementType: DataType) extends Expression {
-
- private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
- private lazy val completeFunction = function(loopAttribute)
+ loopVar: LambdaVariable,
+ lambdaFunction: Expression,
+ inputData: Expression) extends Expression {
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case NullType =>
@@ -402,37 +407,23 @@ case class MapObjects(
override def nullable: Boolean = true
- override def children: Seq[Expression] = completeFunction :: inputData :: Nil
+ override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
- override def dataType: DataType = ArrayType(completeFunction.dataType)
+ override def dataType: DataType = ArrayType(lambdaFunction.dataType)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
- val elementJavaType = ctx.javaType(elementType)
+ val elementJavaType = ctx.javaType(loopVar.dataType)
val genInputData = inputData.gen(ctx)
-
- // Variables to hold the element that is currently being processed.
- val loopValue = ctx.freshName("loopValue")
- val loopIsNull = ctx.freshName("loopIsNull")
-
- val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
- val substitutedFunction = completeFunction transform {
- case a: AttributeReference if a == loopAttribute => loopVariable
- }
- // A hack to run this through the analyzer (to bind extractions).
- val boundFunction =
- SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
- .expressions.head.children.head
-
- val genFunction = boundFunction.gen(ctx)
+ val genFunction = lambdaFunction.gen(ctx)
val dataLength = ctx.freshName("dataLength")
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
- val convertedType = ctx.boxedType(boundFunction.dataType)
+ val convertedType = ctx.boxedType(lambdaFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -446,9 +437,9 @@ case class MapObjects(
}
val loopNullCheck = if (primitiveElement) {
- s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+ s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
} else {
- s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+ s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
s"""
@@ -464,11 +455,11 @@ case class MapObjects(
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- $elementJavaType $loopValue =
+ $elementJavaType ${loopVar.value} =
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
- if ($loopIsNull) {
+ if (${loopVar.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
${genFunction.code}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d6ca138672..7233e0f1b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -145,6 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
case class InnerClass(i: Int)
productTest(InnerClass(1))
+ encodeDecodeTest(Array(InnerClass(1)), "array of inner class")
productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))