aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-08-18 22:08:15 -0700
committerReynold Xin <rxin@databricks.com>2015-08-18 22:08:15 -0700
commit1ff0580eda90f9247a5233809667f5cebaea290e (patch)
tree29d85b57bf0870f316f510c9823f5805fe31dc2a
parent270ee677750a1f2adaf24b5816857194e61782ff (diff)
downloadspark-1ff0580eda90f9247a5233809667f5cebaea290e.tar.gz
spark-1ff0580eda90f9247a5233809667f5cebaea290e.tar.bz2
spark-1ff0580eda90f9247a5233809667f5cebaea290e.zip
[SPARK-10093] [SPARK-10096] [SQL] Avoid transformation on executors & fix UDFs on complex types
This is kind of a weird case, but given a sufficiently complex query plan (in this case a TungstenProject with an Exchange underneath), we could have NPEs on the executors due to the time when we were calling transformAllExpressions In general we should ensure that all transformations occur on the driver and not on the executors. Some reasons for avoid executor side transformations include: * (this case) Some operator constructors require state such as access to the Spark/SQL conf so doing a makeCopy on the executor can fail. * (unrelated reason for avoid executor transformations) ExprIds are calculated using an atomic integer, so you can violate their uniqueness constraint by constructing them anywhere other than the driver. This subsumes #8285. Author: Reynold Xin <rxin@databricks.com> Author: Michael Armbrust <michael@databricks.com> Closes #8295 from rxin/SPARK-10096.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala46
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala9
4 files changed, 68 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 298aee3499..1c54671973 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -206,7 +206,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = false
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+ override def eval(input: InternalRow): Any = {
+ InternalRow(children.map(_.eval(input)): _*)
+ }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = GenerateUnsafeProjection.createCode(ctx, children)
@@ -244,7 +246,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
override def nullable: Boolean = false
- override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+ override def eval(input: InternalRow): Any = {
+ InternalRow(valExprs.map(_.eval(input)): _*)
+ }
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 77b98064a9..3f68b05a24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -75,14 +75,16 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+ /** Rewrite the project list to use unsafe expressions as needed. */
+ protected val unsafeProjectList = projectList.map(_ transform {
+ case CreateStruct(children) => CreateStructUnsafe(children)
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ })
+
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitions { iter =>
- this.transformAllExpressions {
- case CreateStruct(children) => CreateStructUnsafe(children)
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
- }
- val project = UnsafeProjection.create(projectList, child.output)
+ val project = UnsafeProjection.create(unsafeProjectList, child.output)
iter.map { row =>
numRows += 1
project(row)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
new file mode 100644
index 0000000000..3c359dd840
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
+ */
+class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("UDF on struct") {
+ val f = udf((a: String) => a)
+ val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ df.select(struct($"a").as("s")).select(f($"s.a")).collect()
+ }
+
+ test("UDF on named_struct") {
+ val f = udf((a: String) => a)
+ val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
+ }
+
+ test("UDF on array") {
+ val f = udf((a: String) => a)
+ val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
+ df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1e2aaae52c..284fff1840 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -878,4 +878,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b")
checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil)
}
+
+ test("SPARK-10093: Avoid transformations on executors") {
+ val df = Seq((1, 1)).toDF("a", "b")
+ df.where($"a" === 1)
+ .select($"a", $"b", struct($"b"))
+ .orderBy("a")
+ .select(struct($"b"))
+ .collect()
+ }
}