about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Herman van Hovell <hvanhovell@questtec.nl>, 2015-11-06 12:21:53 -0800
committer: Michael Armbrust <michael@databricks.com>, 2015-11-06 12:21:53 -0800
commit: f328fedafd7bd084470a5e402de0429b5b7f8cd7 (patch)
tree: 72bc8976d41d38c0118eac17b64a60b404088861 /sql
parent: 49f1a820372d1cba41f3f00d07eb5728f2ed6705 (diff)
download: spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.tar.gz
download: spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.tar.bz2
download: spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.zip
[SPARK-11450] [SQL] Add Unsafe Row processing to Expand
This PR enables the Expand operator to process and produce Unsafe Rows.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9414 from hvanhovell/SPARK-11450.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala | 19
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 8
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala | 54
4 files changed, 73 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index a6fe730f6d..79dabe8e92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -128,7 +128,11 @@ object UnsafeProjection {
* Returns an UnsafeProjection for given sequence of Expressions (bounded).
*/
def create(exprs: Seq[Expression]): UnsafeProjection = {
- GenerateUnsafeProjection.generate(exprs)
+ val unsafeExprs = exprs.map(_ transform {
+ case CreateStruct(children) => CreateStructUnsafe(children)
+ case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+ })
+ GenerateUnsafeProjection.generate(unsafeExprs)
}
def create(expr: Expression): UnsafeProjection = create(Seq(expr))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index a458881f40..55e95769d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -41,14 +41,21 @@ case class Expand(
// as UNKNOWN partitioning
override def outputPartitioning: Partitioning = UnknownPartitioning(0)
+ override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
+ override def canProcessUnsafeRows: Boolean = true
+ override def canProcessSafeRows: Boolean = true
+
+ private[this] val projection = {
+ if (outputsUnsafeRows) {
+ (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
+ } else {
+ (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)()
+ }
+ }
+
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
- // TODO Move out projection objects creation and transfer to
- // workers via closure. However we can't assume the Projection
- // is serializable because of the code gen, so we have to
- // create the projections within each of the partition processing.
- val groups = projections.map(ee => newProjection(ee, child.output)).toArray
-
+ val groups = projections.map(projection).toArray
new Iterator[InternalRow] {
private[this] var result: InternalRow = _
private[this] var idx = -1 // -1 means the initial state
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d5a803f8c4..799650a4f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -67,16 +67,10 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- /** Rewrite the project list to use unsafe expressions as needed. */
- protected val unsafeProjectList = projectList.map(_ transform {
- case CreateStruct(children) => CreateStructUnsafe(children)
- case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
- })
-
protected override def doExecute(): RDD[InternalRow] = {
val numRows = longMetric("numRows")
child.execute().mapPartitions { iter =>
- val project = UnsafeProjection.create(unsafeProjectList, child.output)
+ val project = UnsafeProjection.create(projectList, child.output)
iter.map { row =>
numRows += 1
project(row)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala
new file mode 100644
index 0000000000..faef76d52a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.IntegerType
+
+class ExpandSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.localSeqToDataFrameHolder
+
+ private def testExpand(f: SparkPlan => SparkPlan): Unit = {
+ val input = (1 to 1000).map(Tuple1.apply)
+ val projections = Seq.tabulate(2) { i =>
+ Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
+ }
+ val attributes = projections.head.map(_.toAttribute)
+ checkAnswer(
+ input.toDF(),
+ plan => Expand(projections, attributes, f(plan)),
+ input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
+ )
+ }
+
+ test("inheriting child row type") { // Expand must report the same row format as its child.
+ val exprs = AttributeReference("a", IntegerType, false)() :: Nil
+ val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
+ assert(plan.outputsUnsafeRows, "Expand should inherit the created row type from its child.")
+ }
+
+ test("expanding UnsafeRows") {
+ testExpand(ConvertToUnsafe)
+ }
+
+ test("expanding SafeRows") {
+ testExpand(identity)
+ }
+}