author    Wenchen Fan <wenchen@databricks.com>  2016-04-06 12:09:10 +0800
committer Wenchen Fan <wenchen@databricks.com>  2016-04-06 12:09:10 +0800
commit    f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9 (patch)
tree      2df64f23addd73b5d79988f4a9e7cf4ed188b8f6
parent    8e5c1cbf2c3d5eaa7d9dd35def177414a0d4cf82 (diff)
[SPARK-14296][SQL] whole stage codegen support for Dataset.map
## What changes were proposed in this pull request?

This PR adds a new operator `MapElements` for `Dataset.map`. It is a 1-1 mapping and is easier to adapt to the whole stage codegen framework.

## How was this patch tested?

New test in `WholeStageCodegenSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12087 from cloud-fan/map.
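As a rough usage sketch (not part of the patch; `sqlContext` stands for any live SQLContext with its implicits imported), the kind of typed pipeline this change targets looks like:

```scala
// Hypothetical sketch: after this patch, the .map calls below plan as
// MapElements and are fused into the surrounding whole stage codegen stage
// instead of falling back to an iterator-based MapPartitions.
import sqlContext.implicits._

val doubledAsString = sqlContext.range(100).map(_.toString).map(s => s + s)
```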
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala   2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala   40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala   9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                            22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala          2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala        11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                  69
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala                   86
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                          5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala   14
11 files changed, 247 insertions(+), 41 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b2f362b6b8..4ec43aba02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
* @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
* if we want to resolve deserializer by children output.
*/
-case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
extends UnaryExpression with Unevaluable with NonSQLExpression {
// The input attributes used to resolve deserializer expression must be all resolved.
require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index eebd43dae9..a0490e1351 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -119,18 +119,18 @@ case class Invoke(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
- lazy val method = targetObject.dataType match {
+ @transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
- cls
- .getMethods
- .find(_.getName == functionName)
- .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
- .getReturnType
- .getName
- case _ => ""
+ val m = cls.getMethods.find(_.getName == functionName)
+ if (m.isEmpty) {
+ sys.error(s"Couldn't find $functionName on $cls")
+ } else {
+ m
+ }
+ case _ => None
}
- lazy val unboxer = (dataType, method) match {
+ lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
@@ -157,21 +157,31 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
+ s"boolean ${ev.isNull} = ${ev.value} == null;"
} else {
+ ev.isNull = obj.isNull
""
}
val value = unboxer(s"${obj.value}.$functionName($argString)")
+ val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
+ s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+ } else {
+ s"""
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ try {
+ ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
+ }
+
s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
-
- boolean ${ev.isNull} = ${obj.isNull};
- $javaType ${ev.value} =
- ${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : ($javaType) $value;
+ $evaluate
$objNullCheck
"""
}
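A note on the `@transient lazy val method` change above: `java.lang.reflect.Method` is not serializable, so caching it in a plain `lazy val` on an expression that gets shipped to executors could break task serialization once the value has been forced. The standalone sketch below (illustrative only, not Spark code) shows the general pattern: a `@transient lazy val` is dropped during serialization and simply re-resolved on first access afterwards.

```scala
// Illustrative only: why a cached reflection handle should be @transient.
class ReflectiveCaller(className: String, methodName: String) extends Serializable {
  // java.lang.reflect.Method is not Serializable; marking the cache @transient
  // means it is not captured during serialization and is recomputed lazily on
  // the receiving side.
  @transient private lazy val method: Option[java.lang.reflect.Method] =
    Class.forName(className).getMethods.find(_.getName == methodName)

  def returnTypeName: String =
    method.map(_.getReturnType.getName).getOrElse("<not found>")
}
```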
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 69b09bcb35..c085a377ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
* representation of data item. For example back to back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
+ // TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] {
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
+
+ case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
+ val childWithoutSerialization = child.withObjectOutput
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
}
}
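The new `MapElements` case mirrors the existing `MapPartitions` case: when two object operators are chained and the downstream deserializer expects exactly the object type the upstream operator produces, the rule rewires the downstream operator to read that object directly, skipping an encode/decode round trip. A sketch of the user-level pattern that benefits (assuming a live `sqlContext`; the `Data` class is illustrative):

```scala
import sqlContext.implicits._

case class Data(l: Long, s: String)

val ds = Seq(Data(1L, "a"), Data(2L, "b")).toDS()

// Without EliminateSerialization, the Data produced by the first map would be
// serialized to an InternalRow and deserialized again before the second map.
// With the rule, the second MapElements consumes the object output directly.
val chained = ds.map(d => Data(d.l + 1, d.s)).map(d => Data(d.l * 2, d.s))
```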
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 58313c7b72..ec33a538a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -65,7 +65,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -83,6 +83,30 @@ case class MapPartitions(
serializer: Seq[NamedExpression],
child: LogicalPlan) extends UnaryNode with ObjectOperator
+object MapElements {
+ def apply[T : Encoder, U : Encoder](
+ func: AnyRef,
+ child: LogicalPlan): MapElements = {
+ MapElements(
+ func,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
+ encoderFor[U].namedExpressions,
+ child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child`.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
+ * @param serializer use to serialize the output of `func`.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
+
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -90,7 +114,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
- UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f472a5068e..2854d5f9da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -766,7 +766,8 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+ withTypedPlan {
Project(
leftData :: rightData :: Nil,
joined.analyzed)
@@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+ def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+ MapElements[T, U](func, logicalPlan)
+ }
/**
* :: Experimental ::
@@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ implicit val uEnc = encoder
+ withTypedPlan(MapElements[T, U](func, logicalPlan))
+ }
/**
* :: Experimental ::
@@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
- @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
- new Dataset[T](sqlContext, logicalPlan, encoder)
+ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+ Dataset(sqlContext, logicalPlan)
}
-
- private[sql] def withTypedPlan[R](
- other: Dataset[_], encoder: Encoder[R])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
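For reference, both `map` overloads after this change can be exercised as follows (a sketch assuming a live `sqlContext`; the Java-friendly overload is shown called from Scala):

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.Encoders
import sqlContext.implicits._

// Scala overload: the result encoder is resolved implicitly.
val lengths = sqlContext.range(10).map(_.toString).map(_.length)

// Java-friendly overload: a function object plus an explicit encoder.
val upper = sqlContext.range(10).map(_.toString).map(
  new MapFunction[String, String] { def call(s: String): String = s.toUpperCase },
  Encoders.STRING)
```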
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e52f05a5f4..5f3128d8e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+ case logical.MapElements(f, in, out, child) =>
+ execution.MapElements(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
case logical.MapGroups(f, key, in, out, grouping, data, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 9f539c4929..4e75a3a794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
- |${evaluated}
+ |$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
@@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
/**
* Returns source code to evaluate the variables for required attributes, and clear the code
- * of evaluated variables, to prevent them to be evaluated twice..
+ * of evaluated variables, to prevent them to be evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
- var evaluateVars = ""
+ val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars += ev.code.trim + "\n"
+ evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
- evaluateVars
+ evaluateVars.toString()
}
/**
@@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 582dda8603..f48f3f09c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
@@ -68,6 +71,70 @@ case class MapPartitions(
}
/**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * Note that, each serializer expression needs the result object which is returned by the given
+ * function, as input. This operator uses some tricks to make sure we only calculate the result
+ * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with
+ * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of
+ * a project while explain.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val (funcClass, methodName) = func match {
+ case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+ case _ => classOf[Any => Any] -> "apply"
+ }
+ val funcObj = Literal.create(func, ObjectType(funcClass))
+ val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+ val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(callFunc, child.output))
+ ctx.currentVars = input
+ val evaluated = bound.gen(ctx)
+
+ val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+ val outputFields = serializer.map(_ transform {
+ case _: BoundReference => resultObj
+ })
+ val resultVars = outputFields.map(_.gen(ctx))
+ s"""
+ ${evaluated.code}
+ ${consume(ctx, resultVars)}
+ """
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val callFunc: Any => Any = func match {
+ case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+ case _ => func.asInstanceOf[Any => Any]
+ }
+ child.execute().mapPartitionsInternal { iter =>
+ val getObject = generateToObject(deserializer, child.output)
+ val outputObject = generateToRow(serializer)
+ iter.map(row => outputObject(callFunc(getObject(row))))
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
* Applies the given function to each input row, appending the encoded result at the end of the row.
*/
case class AppendColumns(
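To make the codegen path above concrete, here is a hand-written (non-generated) sketch of the per-row work the fused code amounts to for something like `range(n).map(_.toString)`; the identifiers are illustrative, not the actual generated names:

```scala
import org.apache.spark.unsafe.types.UTF8String

val func: Long => String = _.toString

// One fused step per input row: deserialize the column into an object, invoke
// the user function exactly once, then let every serializer expression read
// the same cached result object.
def processRow(id: Long): UTF8String = {
  val inputObject: Long = id                    // deserializer: column value -> JVM object
  val resultObject: String = func(inputObject)  // user function, invoked once per row
  UTF8String.fromString(resultObject)           // serializer: object -> internal representation
}
```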
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
new file mode 100644
index 0000000000..6eb952445f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
+ */
+object DatasetBenchmark {
+
+ case class Data(l: Long, s: String)
+
+ def main(args: Array[String]): Unit = {
+ val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+ val sqlContext = new SQLContext(sparkContext)
+
+ import sqlContext.implicits._
+
+ val numRows = 10000000
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val numChains = 10
+
+ val benchmark = new Benchmark("back-to-back map", numRows)
+
+ val func = (d: Data) => Data(d.l + 1, d.s)
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.select($"l" + 1 as "l")
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = rdd.map(func)
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+ back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 902 / 995 11.1 90.2 1.0X
+ DataFrame 132 / 167 75.5 13.2 6.8X
+ RDD 216 / 237 46.3 21.6 4.2X
+ */
+ benchmark.run()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f7f3bd78e9..4e62fac919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
- case _: MapPartitions => return
- case _: MapGroups => return
- case _: AppendColumns => return
- case _: CoGroup => return
+ case _: ObjectOperator => return
case _: LogicalRelation => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6d5be0b5dd..f73ca887f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.Row
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
+
+ test("MapElements should be included in WholeStageCodegen") {
+ import testImplicits._
+
+ val ds = sqlContext.range(10).map(_.toString)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+ assert(ds.collect() === 0.until(10).map(_.toString).toArray)
+ }
}
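Outside the test suite, one way to eyeball the same fusion interactively (assuming a shell with a live `sqlContext`) is to print the executed plan:

```scala
import sqlContext.implicits._

val ds = sqlContext.range(10).map(_.toString)

// The executed plan should contain a WholeStageCodegen node whose child
// subtree includes the new MapElements operator over Range, rather than a
// standalone MapPartitions operator.
println(ds.queryExecution.executedPlan)
```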