diff options
author | Reynold Xin <rxin@databricks.com> | 2015-08-18 22:08:15 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-18 22:08:15 -0700 |
commit | 1ff0580eda90f9247a5233809667f5cebaea290e (patch) | |
tree | 29d85b57bf0870f316f510c9823f5805fe31dc2a | |
parent | 270ee677750a1f2adaf24b5816857194e61782ff (diff) | |
download | spark-1ff0580eda90f9247a5233809667f5cebaea290e.tar.gz spark-1ff0580eda90f9247a5233809667f5cebaea290e.tar.bz2 spark-1ff0580eda90f9247a5233809667f5cebaea290e.zip |
[SPARK-10093] [SPARK-10096] [SQL] Avoid transformation on executors & fix UDFs on complex types
This is kind of a weird case, but given a sufficiently complex query plan (in this case a TungstenProject with an Exchange underneath), we could have NPEs on the executors because of the point at which we were calling transformAllExpressions.
In general we should ensure that all transformations occur on the driver and not on the executors. Some reasons for avoiding executor-side transformations include:
* (this case) Some operator constructors require state such as access to the Spark/SQL conf so doing a makeCopy on the executor can fail.
* (unrelated reason for avoiding executor transformations) ExprIds are calculated using an atomic integer, so you can violate their uniqueness constraint by constructing them anywhere other than the driver.
This subsumes #8285.
Author: Reynold Xin <rxin@databricks.com>
Author: Michael Armbrust <michael@databricks.com>
Closes #8295 from rxin/SPARK-10096.
4 files changed, 68 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 298aee3499..1c54671973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -206,7 +206,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(children.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) @@ -244,7 +246,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression override def nullable: Boolean = false - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + override def eval(input: InternalRow): Any = { + InternalRow(valExprs.map(_.eval(input)): _*) + } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 77b98064a9..3f68b05a24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -75,14 +75,16 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) + /** Rewrite the project list to use unsafe expressions as needed. 
*/ + protected val unsafeProjectList = projectList.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - this.transformAllExpressions { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - } - val project = UnsafeProjection.create(projectList, child.output) + val project = UnsafeProjection.create(unsafeProjectList, child.output) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala new file mode 100644 index 0000000000..3c359dd840 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map). + */ +class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("UDF on struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(struct($"a").as("s")).select(f($"s.a")).collect() + } + + test("UDF on named_struct") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() + } + + test("UDF on array") { + val f = udf((a: String) => a) + val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1e2aaae52c..284fff1840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -878,4 +878,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b") checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil) } + + test("SPARK-10093: Avoid transformations on executors") { + val df = Seq((1, 1)).toDF("a", "b") + df.where($"a" === 1) + .select($"a", $"b", struct($"b")) + .orderBy("a") + .select(struct($"b")) + .collect() + } } |