Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala               |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala     |  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala             |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala           |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala               | 13
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala | 62
6 files changed, 86 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9816b33ae8..d9f36f7f87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2230,8 +2230,8 @@ class Analyzer(
val result = resolved transformDown {
case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
inputData.dataType match {
- case ArrayType(et, _) =>
- val expr = MapObjects(func, inputData, et, cls) transformUp {
+ case ArrayType(et, cn) =>
+ val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index f446c3e4a7..1a202ecf74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -451,6 +451,8 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
+ * @param elementNullable When false, indicates that elements in the collection are
+ *                        always non-null.
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
* or None (returning ArrayType)
*/
@@ -458,11 +460,12 @@ object MapObjects {
function: Expression => Expression,
inputData: Expression,
elementType: DataType,
+ elementNullable: Boolean = true,
customCollectionCls: Option[Class[_]] = None): MapObjects = {
val id = curId.getAndIncrement()
val loopValue = s"MapObjects_loopValue$id"
val loopIsNull = s"MapObjects_loopIsNull$id"
- val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
+ val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable)
MapObjects(
loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls)
}
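
With the new parameter in place, callers such as the Analyzer change above can propagate an ArrayType's containsNull flag into the loop variable. A minimal sketch of the effect (the object name and setup are hypothetical, assuming a Catalyst build that includes this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object ElementNullableSketch {
  def main(args: Array[String]): Unit = {
    // An array column whose elements can never be null.
    val input = BoundReference(0, ArrayType(IntegerType, containsNull = false), nullable = false)

    // Identity lambda over the elements; elementNullable = false marks the
    // generated LambdaVariable as non-nullable, so codegen can skip the
    // per-element null check.
    val mapped = MapObjects(identity, input, IntegerType, elementNullable = false)
    println(mapped.lambdaFunction.nullable) // prints: false
  }
}
```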
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d221b0611a..dd768d18e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
CostBasedJoinReorder(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates(conf)) ::
- Batch("Typed Filter Optimization", fixedPoint,
+ Batch("Object Expressions Optimization", fixedPoint,
+ EliminateMapObjects,
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 8445ee06bd..ea2c5d241d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -368,6 +369,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] {
case EqualNullSafe(Literal(null, _), r) => IsNull(r)
case EqualNullSafe(l, Literal(null, _)) => IsNull(l)
+ case AssertNotNull(c, _) if !c.nullable => c
+
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
val newChildren = children.filterNot(isNullLiteral)
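
This extra NullPropagation case lets the optimizer drop null assertions that can never fire, such as an AssertNotNull over a loop variable that is already known to be non-nullable. A minimal sketch (the object name is hypothetical) reproducing the pattern on a non-nullable child:

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.types.IntegerType

object AssertNotNullFoldSketch {
  def main(args: Array[String]): Unit = {
    // A child expression the planner already knows cannot be null.
    val child = BoundReference(0, IntegerType, nullable = false)
    val asserted = AssertNotNull(child, Nil)

    // Mirrors the rule added above: AssertNotNull over a non-nullable
    // child folds away to the child itself.
    val simplified = asserted match {
      case AssertNotNull(c, _) if !c.nullable => c
      case other => other
    }
    println(simplified == child) // prints: true
  }
}
```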
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 257dbfac8c..8cdc6425bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Removes MapObjects when the following conditions are satisfied:
+ *   1. MapObjects(... LambdaVariable(..., false) ...), which means the input and output
+ *      element types are primitive and non-nullable.
+ *   2. No custom collection class is specified as the representation of the data items.
+ */
+object EliminateMapObjects extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData
+ }
+}
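
Applied directly, the rule rewrites an identity MapObjects over a non-nullable primitive array into a plain reference to its input. A minimal sketch built with the Catalyst DSL (hypothetical names, assuming the patched MapObjects signature above):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, IntegerType}

object EliminateMapObjectsSketch {
  def main(args: Array[String]): Unit = {
    val relation = LocalRelation('a.array(ArrayType(IntegerType, containsNull = false)))
    val elems = relation.output.head

    // Identity mapping over non-nullable int elements, no custom collection
    // class: exactly the shape the rule eliminates.
    val mapped = MapObjects(identity, elems, IntegerType, elementNullable = false)
    val plan = relation.select(Alias(mapped, "out")())

    // The MapObjects node is replaced by its input attribute.
    println(EliminateMapObjects(plan))
  }
}
```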
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
new file mode 100644
index 0000000000..d4f37e2a5e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types._
+
+class EliminateMapObjectsSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = {
+ Batch("EliminateMapObjects", FixedPoint(50),
+ NullPropagation(conf),
+ SimplifyCasts,
+ EliminateMapObjects) :: Nil
+ }
+ }
+
+ implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
+ implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]()
+
+ test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
+ val intObjType = ObjectType(classOf[Array[Int]])
+ val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
+ val intQuery = intInput.deserialize[Array[Int]].analyze
+ val intOptimized = Optimize.execute(intQuery)
+ val intExpected = DeserializeToObject(
+ Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
+ AttributeReference("obj", intObjType, true)(), intInput)
+ comparePlans(intOptimized, intExpected)
+
+ val doubleObjType = ObjectType(classOf[Array[Double]])
+ val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
+ val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
+ val doubleOptimized = Optimize.execute(doubleQuery)
+ val doubleExpected = DeserializeToObject(
+ Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
+ AttributeReference("obj", doubleObjType, true)(), doubleInput)
+ comparePlans(doubleOptimized, doubleExpected)
+ }
+}
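
For context, this is the user-facing query shape that benefits: deserializing a primitive array column now calls toIntArray/toDoubleArray directly instead of looping element by element through MapObjects, as the expected plans in the test show. A sketch assuming a standard local SparkSession (not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

object PrimitiveArraySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-20254 sketch").getOrCreate()
    import spark.implicits._

    // A Dataset whose element type is a primitive array.
    val ds = Seq(Array(1, 2, 3), Array(4, 5)).toDS()

    // Deserializing each row back to Array[Int] previously went through a
    // MapObjects loop; with this patch it becomes a single toIntArray call.
    println(ds.map(_.sum).collect().mkString(", ")) // prints: 6, 9
    spark.stop()
  }
}
```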