| author | Wenchen Fan <wenchen@databricks.com> | 2016-04-06 12:09:10 +0800 |
|---|---|---|
| committer | Wenchen Fan <wenchen@databricks.com> | 2016-04-06 12:09:10 +0800 |
| commit | f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9 (patch) | |
| tree | 2df64f23addd73b5d79988f4a9e7cf4ed188b8f6 /sql | |
| parent | 8e5c1cbf2c3d5eaa7d9dd35def177414a0d4cf82 (diff) | |
[SPARK-14296][SQL] whole stage codegen support for Dataset.map
## What changes were proposed in this pull request?
This PR adds a new operator, `MapElements`, for `Dataset.map`. It is a 1-to-1 mapping and is easier to adapt to the whole-stage codegen framework than the existing `MapPartitions`-based implementation.
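As a concrete illustration (an editor's example mirroring the new test, not part of the patch), a typed map like the one below should now plan as `MapElements` inside a `WholeStageCodegen` node rather than as a `MapPartitions` iterator:

```scala
// Hypothetical spark-shell session against a build containing this patch.
import sqlContext.implicits._

val ds = sqlContext.range(10).map(_.toString)
ds.explain() // expected to show MapElements under a WholeStageCodegen node
```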
## How was this patch tested?
new test in `WholeStageCodegenSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12087 from cloud-fan/map.
Diffstat (limited to 'sql')
11 files changed, 247 insertions, 41 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b2f362b6b8..4ec43aba02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
  * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
  *                        if we want to resolve deserializer by children output.
  */
-case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute])
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
   extends UnaryExpression with Unevaluable with NonSQLExpression {
   // The input attributes used to resolve deserializer expression must be all resolved.
   require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index eebd43dae9..a0490e1351 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -119,18 +119,18 @@ case class Invoke(
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
-  lazy val method = targetObject.dataType match {
+  @transient lazy val method = targetObject.dataType match {
     case ObjectType(cls) =>
-      cls
-        .getMethods
-        .find(_.getName == functionName)
-        .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
-        .getReturnType
-        .getName
-    case _ => ""
+      val m = cls.getMethods.find(_.getName == functionName)
+      if (m.isEmpty) {
+        sys.error(s"Couldn't find $functionName on $cls")
+      } else {
+        m
+      }
+    case _ => None
   }
 
-  lazy val unboxer = (dataType, method) match {
+  lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
     case (IntegerType, "java.lang.Object") => (s: String) =>
       s"((java.lang.Integer)$s).intValue()"
     case (LongType, "java.lang.Object") => (s: String) =>
@@ -157,21 +157,31 @@ case class Invoke(
     // If the function can return null, we do an extra check to make sure our null bit is still set
     // correctly.
     val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"${ev.isNull} = ${ev.value} == null;"
+      s"boolean ${ev.isNull} = ${ev.value} == null;"
     } else {
+      ev.isNull = obj.isNull
       ""
     }
 
     val value = unboxer(s"${obj.value}.$functionName($argString)")
+
+    val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
+      s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+    } else {
+      s"""
+        $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+        try {
+          ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+        } catch (Exception e) {
+          org.apache.spark.unsafe.Platform.throwException(e);
+        }
+      """
+    }
+
     s"""
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
-
-      boolean ${ev.isNull} = ${obj.isNull};
-      $javaType ${ev.value} =
-        ${ev.isNull} ?
-          ${ctx.defaultValue(dataType)} : ($javaType) $value;
+      $evaluate
       $objNullCheck
     """
   }
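A note on the new try/catch branch in `Invoke`: the generated Java method declares no checked exceptions, but a user function such as `MapFunction.call` is declared `throws Exception`, so the generated code rethrows through `Platform.throwException`, which uses `sun.misc.Unsafe` to throw without declaring. A minimal sketch of the idiom (editor's illustration; Scala itself doesn't enforce checked exceptions, so the constraint only bites in the generated Java):

```scala
import org.apache.spark.unsafe.Platform

// Run `body` and rethrow any exception as-is without it appearing in the
// method signature -- the same trick the generated catch block uses.
def callUnchecked[T](body: => T): T =
  try body catch {
    case e: Exception =>
      Platform.throwException(e) // throws e, bypassing checked-exception rules
      throw new IllegalStateException("unreachable")
  }
```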
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 69b09bcb35..c085a377ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * representation of data item.  For example back to back map operations.
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
+  // TODO: find a more general way to do this optimization.
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
         if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       m.copy(
         deserializer = childWithoutSerialization.output.head,
         child = childWithoutSerialization)
+
+    case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+        if !deserializer.isInstanceOf[Attribute] &&
+          deserializer.dataType == child.outputObject.dataType =>
+      val childWithoutSerialization = child.withObjectOutput
+      m.copy(
+        deserializer = childWithoutSerialization.output.head,
+        child = childWithoutSerialization)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 58313c7b72..ec33a538a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -65,7 +65,7 @@ object MapPartitions {
       child: LogicalPlan): MapPartitions = {
     MapPartitions(
       func.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+      UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
   }
@@ -83,6 +83,30 @@ case class MapPartitions(
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator
 
+object MapElements {
+  def apply[T : Encoder, U : Encoder](
+      func: AnyRef,
+      child: LogicalPlan): MapElements = {
+    MapElements(
+      func,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      encoderFor[U].namedExpressions,
+      child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child`.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class MapElements(
+    func: AnyRef,
+    deserializer: Expression,
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectOperator
+
 /** Factory for constructing new `AppendColumns` nodes.
  */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
@@ -90,7 +114,7 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
-      UnresolvedDeserializer(encoderFor[T].deserializer, Nil),
+      UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
   }
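To see what the new `EliminateSerialization` case buys, consider back-to-back typed maps (a hypothetical example, not from the patch): without the rule, the object produced by the first function is serialized to an `InternalRow` and immediately deserialized for the second function; with the rule, the second `MapElements` consumes the first one's object output directly.

```scala
import sqlContext.implicits._

// Two chained typed maps: the String produced by the first function is
// handed to the second function as an object, with no row round trip.
val chained = sqlContext.range(10).map(_.toString).map(_.length)
chained.explain()
```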
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f472a5068e..2854d5f9da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -766,7 +766,8 @@ class Dataset[T] private[sql](
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
-    withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+    withTypedPlan {
       Project(
         leftData :: rightData :: Nil,
         joined.analyzed)
@@ -1900,7 +1901,9 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+  def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+    MapElements[T, U](func, logicalPlan)
+  }
 
   /**
    * :: Experimental ::
@@ -1911,8 +1914,10 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
-    map(t => func.call(t))(encoder)
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    implicit val uEnc = encoder
+    withTypedPlan(MapElements[T, U](func, logicalPlan))
+  }
 
   /**
    * :: Experimental ::
@@ -2412,12 +2417,7 @@ class Dataset[T] private[sql](
   }
 
   /** A convenient function to wrap a logical plan and produce a Dataset. */
-  @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
-    new Dataset[T](sqlContext, logicalPlan, encoder)
+  @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+    Dataset(sqlContext, logicalPlan)
   }
-
-  private[sql] def withTypedPlan[R](
-      other: Dataset[_], encoder: Encoder[R])(
-      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e52f05a5f4..5f3128d8e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapPartitions(f, in, out, child) =>
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+      case logical.MapElements(f, in, out, child) =>
+        execution.MapElements(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>
         execution.AppendColumns(f, in, out, planLater(child)) :: Nil
       case logical.MapGroups(f, key, in, out, grouping, data, child) =>
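The Java-friendly `map` overload now plans a `MapElements` node directly instead of delegating to the Scala-closure version; the explicit encoder is simply made implicit inside the method. A hedged usage sketch (editor's example) calling the overload from Scala:

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Dataset, Encoders}
import sqlContext.implicits._

val strings: Dataset[String] = sqlContext.range(10).map(_.toString)

// Encoder passed explicitly, as a Java caller would with Encoders.STRING.
val upper: Dataset[String] = strings.map(new MapFunction[String, String] {
  override def call(s: String): String = s.toUpperCase
}, Encoders.STRING)
```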
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 9f539c4929..4e75a3a794 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan {
     s"""
        |
        |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
-       |${evaluated}
+       |$evaluated
        |${parent.doConsume(ctx, inputVars, rowVar)}
      """.stripMargin
   }
@@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan {
   /**
    * Returns source code to evaluate the variables for required attributes, and clear the code
-   * of evaluated variables, to prevent them to be evaluated twice..
+   * of evaluated variables, to prevent them from being evaluated twice.
    */
   protected def evaluateRequiredVariables(
       attributes: Seq[Attribute],
       variables: Seq[ExprCode],
       required: AttributeSet): String = {
-    var evaluateVars = ""
+    val evaluateVars = new StringBuilder
     variables.zipWithIndex.foreach { case (ev, i) =>
       if (ev.code != "" && required.contains(attributes(i))) {
-        evaluateVars += ev.code.trim + "\n"
+        evaluateVars.append(ev.code.trim + "\n")
         ev.code = ""
       }
     }
-    evaluateVars
+    evaluateVars.toString()
   }
 
   /**
@@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
   def doCodeGen(): (CodegenContext, String) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
-    val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
        return new GeneratedIterator(references);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 582dda8603..f48f3f09c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution
 
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.MapFunction
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.ObjectType
 
@@ -68,6 +71,70 @@ case class MapPartitions(
 }
 
 /**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * Note that each serializer expression needs the result object returned by the given function as
+ * input. This operator uses some tricks to make sure we only calculate the result object once. We
+ * don't use [[Project]] directly, as subexpression elimination doesn't work with whole stage
+ * codegen and it would be confusing to show the un-common-subexpression-eliminated version of a
+ * project in the explain output.
+ */
+case class MapElements(
+    func: AnyRef,
+    deserializer: Expression,
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val (funcClass, methodName) = func match {
+      case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+      case _ => classOf[Any => Any] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+    val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+    val bound = ExpressionCanonicalizer.execute(
+      BindReferences.bindReference(callFunc, child.output))
+    ctx.currentVars = input
+    val evaluated = bound.gen(ctx)
+
+    val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+    val outputFields = serializer.map(_ transform {
+      case _: BoundReference => resultObj
+    })
+    val resultVars = outputFields.map(_.gen(ctx))
+    s"""
+      ${evaluated.code}
+      ${consume(ctx, resultVars)}
+    """
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val callFunc: Any => Any = func match {
+      case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+      case _ => func.asInstanceOf[Any => Any]
+    }
+    child.execute().mapPartitionsInternal { iter =>
+      val getObject = generateToObject(deserializer, child.output)
+      val outputObject = generateToRow(serializer)
+      iter.map(row => outputObject(callFunc(getObject(row))))
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
  * Applies the given function to each input row, appending the encoded result at the end of the row.
  */
 case class AppendColumns(
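The "calculate the result object once" trick in `doConsume` can be illustrated outside of codegen (a REPL-style editor's sketch, with hypothetical names): if each serializer expression re-evaluated the user function, a two-field result object would invoke it once per output column; binding the result to a single variable — which is what `LambdaVariable` achieves in the generated code — invokes it once per row.

```scala
case class Result(a: Int, b: String)
val func: Int => Result = i => Result(i, i.toString)

// Naive Project-style evaluation: `func` runs once per output column.
def naiveRow(i: Int): (Int, String) = (func(i).a, func(i).b)

// MapElements-style evaluation: `func` runs once; both columns read from
// the shared result object.
def fusedRow(i: Int): (Int, String) = { val r = func(i); (r.a, r.b) }
```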
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
new file mode 100644
index 0000000000..6eb952445f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
+ */
+object DatasetBenchmark {
+
+  case class Data(l: Long, s: String)
+
+  def main(args: Array[String]): Unit = {
+    val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+    val sqlContext = new SQLContext(sparkContext)
+
+    import sqlContext.implicits._
+
+    val numRows = 10000000
+    val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+    val numChains = 10
+
+    val benchmark = new Benchmark("back-to-back map", numRows)
+
+    val func = (d: Data) => Data(d.l + 1, d.s)
+    benchmark.addCase("Dataset") { iter =>
+      var res = df.as[Data]
+      var i = 0
+      while (i < numChains) {
+        res = res.map(func)
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    benchmark.addCase("DataFrame") { iter =>
+      var res = df
+      var i = 0
+      while (i < numChains) {
+        res = res.select($"l" + 1 as "l")
+        i += 1
+      }
+      res.queryExecution.toRdd.foreach(_ => Unit)
+    }
+
+    val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+    benchmark.addCase("RDD") { iter =>
+      var res = rdd
+      var i = 0
+      while (i < numChains) {
+        res = res.map(func)
+        i += 1
+      }
+      res.foreach(_ => Unit)
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+    back-to-back map:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Dataset                                   902 /  995         11.1          90.2       1.0X
+    DataFrame                                 132 /  167         75.5          13.2       6.8X
+    RDD                                       216 /  237         46.3          21.6       4.2X
+    */
+    benchmark.run()
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f7f3bd78e9..4e62fac919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest {
     val logicalPlan = df.queryExecution.analyzed
     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
-      case _: MapPartitions => return
-      case _: MapGroups => return
-      case _: AppendColumns => return
-      case _: CoGroup => return
+      case _: ObjectOperator => return
       case _: LogicalRelation => return
     }.transformAllExpressions {
       case a: ImperativeAggregate => return
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6d5be0b5dd..f73ca887f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.Row
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
     assert(df.collect() === Array(Row(1), Row(2), Row(3)))
   }
+
+  test("MapElements should be included in WholeStageCodegen") {
+    import testImplicits._
+
+    val ds = sqlContext.range(10).map(_.toString)
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+      p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
+  }
 }