author     Wenchen Fan <wenchen@databricks.com>    2016-04-06 12:09:10 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2016-04-06 12:09:10 +0800
commit     f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9 (patch)
tree       2df64f23addd73b5d79988f4a9e7cf4ed188b8f6 /sql/core
parent     8e5c1cbf2c3d5eaa7d9dd35def177414a0d4cf82 (diff)
[SPARK-14296][SQL] whole stage codegen support for Dataset.map
## What changes were proposed in this pull request?

This PR adds a new operator `MapElements` for `Dataset.map`. It is a 1-to-1 mapping and is easier to adapt to the whole-stage codegen framework.

## How was this patch tested?

New test in `WholeStageCodegenSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12087 from cloud-fan/map.
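As a quick, hedged illustration of the user-visible effect (mirroring the new test in `WholeStageCodegenSuite`; the `sqlContext` value is assumed to already exist, e.g. in a spark-shell-style session), a typed `map` should now be planned as a `MapElements` node fused into a `WholeStageCodegen` stage:

```scala
// Sketch only: assumes a live SQLContext named `sqlContext` is already defined.
import sqlContext.implicits._

val ds = sqlContext.range(10).map(_.toString)
ds.explain()  // the physical plan is expected to show MapElements inside a WholeStageCodegen stage
ds.show()
```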
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                          | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala        |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala      | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                | 69
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala                 | 86
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                        |  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 14
7 files changed, 186 insertions(+), 23 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f472a5068e..2854d5f9da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -766,7 +766,8 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+ withTypedPlan {
Project(
leftData :: rightData :: Nil,
joined.analyzed)
@@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+ def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+ MapElements[T, U](func, logicalPlan)
+ }
/**
* :: Experimental ::
@@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ implicit val uEnc = encoder
+ withTypedPlan(MapElements[T, U](func, logicalPlan))
+ }
/**
* :: Experimental ::
@@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
- @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
- new Dataset[T](sqlContext, logicalPlan, encoder)
+ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+ Dataset(sqlContext, logicalPlan)
}
-
- private[sql] def withTypedPlan[R](
- other: Dataset[_], encoder: Encoder[R])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
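For context on the two `map` overloads touched in the diff above, here is a hedged usage sketch; the `SQLContext` parameter and the sample data are illustrative and not part of this commit:

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Encoders, SQLContext}

// Sketch only: both overloads now go through the new MapElements logical node.
def mapOverloads(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._

  val ds = Seq(1, 2, 3).toDS()

  // Scala overload: the result Encoder comes from the implicit scope (context bound U : Encoder).
  val incremented = ds.map(_ + 1)

  // Java-friendly overload: an explicit MapFunction plus an explicit Encoder for the result type.
  val asString = ds.map(new MapFunction[Int, String] {
    override def call(v: Int): String = v.toString
  }, Encoders.STRING)

  incremented.show()
  asString.show()
}
```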
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e52f05a5f4..5f3128d8e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+ case logical.MapElements(f, in, out, child) =>
+ execution.MapElements(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
case logical.MapGroups(f, key, in, out, grouping, data, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 9f539c4929..4e75a3a794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
- |${evaluated}
+ |$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
@@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
/**
* Returns source code to evaluate the variables for required attributes, and clear the code
- * of evaluated variables, to prevent them to be evaluated twice..
+ * of evaluated variables, to prevent them to be evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
- var evaluateVars = ""
+ val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars += ev.code.trim + "\n"
+ evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
- evaluateVars
+ evaluateVars.toString()
}
/**
@@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 582dda8603..f48f3f09c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
@@ -68,6 +71,70 @@ case class MapPartitions(
}
/**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * Note that each serializer expression needs the result object returned by the given function as
+ * input. This operator uses some tricks to make sure we only calculate the result object once. We
+ * don't use [[Project]] directly because subexpression elimination doesn't work with whole-stage
+ * codegen, and it is confusing to show the version of a project without common-subexpression
+ * elimination in the explain output.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val (funcClass, methodName) = func match {
+ case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+ case _ => classOf[Any => Any] -> "apply"
+ }
+ val funcObj = Literal.create(func, ObjectType(funcClass))
+ val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+ val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(callFunc, child.output))
+ ctx.currentVars = input
+ val evaluated = bound.gen(ctx)
+
+ val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+ val outputFields = serializer.map(_ transform {
+ case _: BoundReference => resultObj
+ })
+ val resultVars = outputFields.map(_.gen(ctx))
+ s"""
+ ${evaluated.code}
+ ${consume(ctx, resultVars)}
+ """
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val callFunc: Any => Any = func match {
+ case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+ case _ => func.asInstanceOf[Any => Any]
+ }
+ child.execute().mapPartitionsInternal { iter =>
+ val getObject = generateToObject(deserializer, child.output)
+ val outputObject = generateToRow(serializer)
+ iter.map(row => outputObject(callFunc(getObject(row))))
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
* Applies the given function to each input row, appending the encoded result at the end of the row.
*/
case class AppendColumns(
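To make the comment on `MapElements` above concrete, here is a minimal, self-contained sketch in plain Scala (illustrative names only, no Spark internals) of why the result object is evaluated once per row and then shared by all serializer expressions:

```scala
// Illustrative only: counts how often the user function runs per input row.
object ResultObjectOnce {
  final case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    var calls = 0
    val func: Long => Person = { i => calls += 1; Person(s"name_$i", i.toInt) }

    // Naive projection: every output field re-invokes the function, which is what a plain
    // Project would amount to under whole-stage codegen without subexpression elimination.
    val naive = (1L to 3L).map(i => (func(i).name, func(i).age))
    println(s"naive: $calls calls for ${naive.size} rows")    // 6 calls

    calls = 0
    // MapElements-style: evaluate the function once per row, then read every field from the result.
    val shared = (1L to 3L).map { i => val p = func(i); (p.name, p.age) }
    println(s"shared: $calls calls for ${shared.size} rows")  // 3 calls
  }
}
```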
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
new file mode 100644
index 0000000000..6eb952445f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
+ */
+object DatasetBenchmark {
+
+ case class Data(l: Long, s: String)
+
+ def main(args: Array[String]): Unit = {
+ val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+ val sqlContext = new SQLContext(sparkContext)
+
+ import sqlContext.implicits._
+
+ val numRows = 10000000
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val numChains = 10
+
+ val benchmark = new Benchmark("back-to-back map", numRows)
+
+ val func = (d: Data) => Data(d.l + 1, d.s)
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.select($"l" + 1 as "l")
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+ back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 902 / 995 11.1 90.2 1.0X
+ DataFrame 132 / 167 75.5 13.2 6.8X
+ RDD 216 / 237 46.3 21.6 4.2X
+ */
+ benchmark.run()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f7f3bd78e9..4e62fac919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
- case _: MapPartitions => return
- case _: MapGroups => return
- case _: AppendColumns => return
- case _: CoGroup => return
+ case _: ObjectOperator => return
case _: LogicalRelation => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6d5be0b5dd..f73ca887f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.Row
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
+
+ test("MapElements should be included in WholeStageCodegen") {
+ import testImplicits._
+
+ val ds = sqlContext.range(10).map(_.toString)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+ assert(ds.collect() === 0.until(10).map(_.toString).toArray)
+ }
}