diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2015-11-06 12:21:53 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-06 12:21:53 -0800 |
commit | f328fedafd7bd084470a5e402de0429b5b7f8cd7 (patch) | |
tree | 72bc8976d41d38c0118eac17b64a60b404088861 /sql | |
parent | 49f1a820372d1cba41f3f00d07eb5728f2ed6705 (diff) | |
download | spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.tar.gz spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.tar.bz2 spark-f328fedafd7bd084470a5e402de0429b5b7f8cd7.zip |
[SPARK-11450] [SQL] Add Unsafe Row processing to Expand
This PR enables the Expand operator to process and produce Unsafe Rows.
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #9414 from hvanhovell/SPARK-11450.
Diffstat (limited to 'sql')
4 files changed, 73 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a6fe730f6d..79dabe8e92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -128,7 +128,11 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a458881f40..55e95769d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,14 +41,21 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + private[this] val projection = { + if (outputsUnsafeRows) { + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + } else { + (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() + } + } + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => - // TODO Move out projection objects creation and transfer to - // workers via closure. However we can't assume the Projection - // is serializable because of the code gen, so we have to - // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee, child.output)).toArray - + val groups = projections.map(projection).toArray new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d5a803f8c4..799650a4f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -67,16 +67,10 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - /** Rewrite the project list to use unsafe expressions as needed. */ - protected val unsafeProjectList = projectList.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(unsafeProjectList, child.output) + val project = UnsafeProjection.create(projectList, child.output) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala new file mode 100644 index 0000000000..faef76d52a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType + +class ExpandSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private def testExpand(f: SparkPlan => SparkPlan): Unit = { + val input = (1 to 1000).map(Tuple1.apply) + val projections = Seq.tabulate(2) { i => + Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil + } + val attributes = projections.head.map(_.toAttribute) + checkAnswer( + input.toDF(), + plan => Expand(projections, attributes, f(plan)), + input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) + ) + } + + test("inheriting child row type") { + val exprs = AttributeReference("a", IntegerType, false)() :: Nil + val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) + assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") + } + + test("expanding UnsafeRows") { + testExpand(ConvertToUnsafe) + } + + test("expanding SafeRows") { + testExpand(identity) + } +} |