aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2016-01-14 17:44:56 -0800
committerMichael Armbrust <michael@databricks.com>2016-01-14 17:44:56 -0800
commitcc7af86afd3e769d1e2a581f31bb3db5a3d0229f (patch)
tree2fbd24829a347a0765a6882f98c87ac555aaa55b
parent25782981cf58946dc7c186acadd2beec5d964461 (diff)
downloadspark-cc7af86afd3e769d1e2a581f31bb3db5a3d0229f.tar.gz
spark-cc7af86afd3e769d1e2a581f31bb3db5a3d0229f.tar.bz2
spark-cc7af86afd3e769d1e2a581f31bb3db5a3d0229f.zip
[SPARK-12813][SQL] Eliminate serialization for back to back operations
The goal of this PR is to eliminate unnecessary translations when there are back-to-back `MapPartitions` operations. In order to achieve this I also made the following simplifications: - Operators no longer have hold encoders, instead they have only the expressions that they need. The benefits here are twofold: the expressions are visible to transformations so go through the normal resolution/binding process. now that they are visible we can change them on a case by case basis. - Operators no longer have type parameters. Since the engine is responsible for its own type checking, having the types visible to the complier was an unnecessary complication. We still leverage the scala compiler in the companion factory when constructing a new operator, but after this the types are discarded. Deferred to a follow up PR: - Remove as much of the resolution/binding from Dataset/GroupedDataset as possible. We should still eagerly check resolution and throw an error though in the case of mismatches for an `as` operation. - Eliminate serializations in more cases by adding more cases to `EliminateSerialization` Author: Michael Armbrust <michael@databricks.com> Closes #10747 from marmbrus/encoderExpressions.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala119
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala185
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala76
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala19
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala127
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala182
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala8
17 files changed, 518 insertions, 274 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8a33af8207..dadea6b54a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1214,6 +1214,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+ // Operators that operate on objects should only have expressions from encoders, which should
+ // never have extra aliases.
+ case o: ObjectOperator => o
+
case other =>
var stop = false
other transformExpressionsDown {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index fc0e87aa68..79eebbf9b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -160,6 +160,7 @@ abstract class Star extends LeafExpression with NamedExpression {
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override lazy val resolved = false
def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
@@ -246,6 +247,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+
override lazy val resolved = false
override def toString: String = s"$child AS $names"
@@ -259,6 +262,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
* @param expressions Expressions to expand.
*/
case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}
@@ -298,6 +302,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def name: String = throw new UnresolvedException(this, "name")
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
override lazy val resolved = false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 05f746e72b..64832dc114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -207,6 +207,16 @@ case class ExpressionEncoder[T](
resolve(attrs, OuterScopes.outerScopes).bind(attrs)
}
+
+ /**
+ * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
+ * of this object.
+ */
+ def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+ case (_, ne: NamedExpression) => ne.newInstance()
+ case (name, e) => Alias(e, name)()
+ }
+
/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 7293d5d447..c94b2c0e27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression with NamedExpression {
- override def toString: String = s"input[$ordinal, $dataType]"
+ override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
@@ -66,6 +66,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def exprId: ExprId = throw new UnsupportedOperationException
+ override def newInstance(): NamedExpression = this
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index eee708cb02..b6d7a7f5e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -79,6 +79,9 @@ trait NamedExpression extends Expression {
/** Returns the metadata when an expression is a reference to another expression with metadata. */
def metadata: Metadata = Metadata.empty
+ /** Returns a copy of this expression with a new `exprId`. */
+ def newInstance(): NamedExpression
+
protected def typeSuffix =
if (resolved) {
dataType match {
@@ -144,6 +147,9 @@ case class Alias(child: Expression, name: String)(
}
}
+ def newInstance(): NamedExpression =
+ Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata)
+
override def toAttribute: Attribute = {
if (resolved) {
AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index c0c3e6e891..8385f7e1da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -172,6 +172,8 @@ case class Invoke(
$objNullCheck
"""
}
+
+ override def toString: String = s"$targetObject.$functionName"
}
object NewInstance {
@@ -253,6 +255,8 @@ case class NewInstance(
"""
}
}
+
+ override def toString: String = s"newInstance($cls)"
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 487431f892..cc3371c08f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -67,7 +67,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
RemoveDispensableExpressions,
SimplifyFilters,
SimplifyCasts,
- SimplifyCaseConversionExpressions) ::
+ SimplifyCaseConversionExpressions,
+ EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
@@ -97,6 +98,19 @@ object SamplePushDown extends Rule[LogicalPlan] {
}
/**
+ * Removes cases where we are unnecessarily going between the object and serialized (InternalRow)
+ * representation of data item. For example back to back map operations.
+ */
+object EliminateSerialization extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case m @ MapPartitions(_, input, _, child: ObjectOperator)
+ if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+ val childWithoutSerialization = child.withObjectOutput
+ m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+ }
+}
+
+/**
* Pushes certain operations to both sides of a Union, Intersect or Except operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64957db6b4..2a1b1b131d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -480,120 +478,3 @@ case object OneRowRelation extends LeafNode {
*/
override def statistics: Statistics = Statistics(sizeInBytes = 1)
}
-
-/**
- * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
- * used respectively to decode/encode from the JVM object representation expected by `func.`
- */
-case class MapPartitions[T, U](
- func: Iterator[T] => Iterator[U],
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- output: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode {
- override def producedAttributes: AttributeSet = outputSet
-}
-
-/** Factory for constructing new `AppendColumn` nodes. */
-object AppendColumns {
- def apply[T, U : Encoder](
- func: T => U,
- tEncoder: ExpressionEncoder[T],
- child: LogicalPlan): AppendColumns[T, U] = {
- val attrs = encoderFor[U].schema.toAttributes
- new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
- }
-}
-
-/**
- * A relation produced by applying `func` to each partition of the `child`, concatenating the
- * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
- * decode/encode from the JVM object representation expected by `func.`
- */
-case class AppendColumns[T, U](
- func: T => U,
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- newColumns: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output ++ newColumns
- override def producedAttributes: AttributeSet = AttributeSet(newColumns)
-}
-
-/** Factory for constructing new `MapGroups` nodes. */
-object MapGroups {
- def apply[K, T, U : Encoder](
- func: (K, Iterator[T]) => TraversableOnce[U],
- kEncoder: ExpressionEncoder[K],
- tEncoder: ExpressionEncoder[T],
- groupingAttributes: Seq[Attribute],
- child: LogicalPlan): MapGroups[K, T, U] = {
- new MapGroups(
- func,
- kEncoder,
- tEncoder,
- encoderFor[U],
- groupingAttributes,
- encoderFor[U].schema.toAttributes,
- child)
- }
-}
-
-/**
- * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
- * Func is invoked with an object representation of the grouping key an iterator containing the
- * object representation of all the rows with that key.
- */
-case class MapGroups[K, T, U](
- func: (K, Iterator[T]) => TraversableOnce[U],
- kEncoder: ExpressionEncoder[K],
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- groupingAttributes: Seq[Attribute],
- output: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode {
- override def producedAttributes: AttributeSet = outputSet
-}
-
-/** Factory for constructing new `CoGroup` nodes. */
-object CoGroup {
- def apply[Key, Left, Right, Result : Encoder](
- func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
- keyEnc: ExpressionEncoder[Key],
- leftEnc: ExpressionEncoder[Left],
- rightEnc: ExpressionEncoder[Right],
- leftGroup: Seq[Attribute],
- rightGroup: Seq[Attribute],
- left: LogicalPlan,
- right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
- CoGroup(
- func,
- keyEnc,
- leftEnc,
- rightEnc,
- encoderFor[Result],
- encoderFor[Result].schema.toAttributes,
- leftGroup,
- rightGroup,
- left,
- right)
- }
-}
-
-/**
- * A relation produced by applying `func` to each grouping key and associated values from left and
- * right children.
- */
-case class CoGroup[Key, Left, Right, Result](
- func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
- keyEnc: ExpressionEncoder[Key],
- leftEnc: ExpressionEncoder[Left],
- rightEnc: ExpressionEncoder[Right],
- resultEnc: ExpressionEncoder[Result],
- output: Seq[Attribute],
- leftGroup: Seq[Attribute],
- rightGroup: Seq[Attribute],
- left: LogicalPlan,
- right: LogicalPlan) extends BinaryNode {
- override def producedAttributes: AttributeSet = outputSet
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
new file mode 100644
index 0000000000..7603480527
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.ObjectType
+
+/**
+ * A trait for logical operators that apply user defined functions to domain objects.
+ */
+trait ObjectOperator extends LogicalPlan {
+
+ /** The serializer that is used to produce the output of this operator. */
+ def serializer: Seq[NamedExpression]
+
+ /**
+ * The object type that is produced by the user defined function. Note that the return type here
+ * is the same whether or not the operator is output serialized data.
+ */
+ def outputObject: NamedExpression =
+ Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")()
+
+ /**
+ * Returns a copy of this operator that will produce an object instead of an encoded row.
+ * Used in the optimizer when transforming plans to remove unneeded serialization.
+ */
+ def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
+ this
+ } else {
+ withNewSerializer(outputObject)
+ }
+
+ /** Returns a copy of this operator with a different serializer. */
+ def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+ productIterator.map {
+ case c if c == serializer => newSerializer :: Nil
+ case other: AnyRef => other
+ }.toArray
+ }
+}
+
+object MapPartitions {
+ def apply[T : Encoder, U : Encoder](
+ func: Iterator[T] => Iterator[U],
+ child: LogicalPlan): MapPartitions = {
+ MapPartitions(
+ func.asInstanceOf[Iterator[Any] => Iterator[Any]],
+ encoderFor[T].fromRowExpression,
+ encoderFor[U].namedExpressions,
+ child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`.
+ * @param input used to extract the input to `func` from an input row.
+ * @param serializer use to serialize the output of `func`.
+ */
+case class MapPartitions(
+ func: Iterator[Any] => Iterator[Any],
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
+
+/** Factory for constructing new `AppendColumn` nodes. */
+object AppendColumns {
+ def apply[T : Encoder, U : Encoder](
+ func: T => U,
+ child: LogicalPlan): AppendColumns = {
+ new AppendColumns(
+ func.asInstanceOf[Any => Any],
+ encoderFor[T].fromRowExpression,
+ encoderFor[U].namedExpressions,
+ child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`, concatenating the
+ * resulting columns at the end of the input row.
+ * @param input used to extract the input to `func` from an input row.
+ * @param serializer use to serialize the output of `func`.
+ */
+case class AppendColumns(
+ func: Any => Any,
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator {
+ override def output: Seq[Attribute] = child.output ++ newColumns
+ def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+}
+
+/** Factory for constructing new `MapGroups` nodes. */
+object MapGroups {
+ def apply[K : Encoder, T : Encoder, U : Encoder](
+ func: (K, Iterator[T]) => TraversableOnce[U],
+ groupingAttributes: Seq[Attribute],
+ child: LogicalPlan): MapGroups = {
+ new MapGroups(
+ func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+ encoderFor[K].fromRowExpression,
+ encoderFor[T].fromRowExpression,
+ encoderFor[U].namedExpressions,
+ groupingAttributes,
+ child)
+ }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
+ * Func is invoked with an object representation of the grouping key an iterator containing the
+ * object representation of all the rows with that key.
+ * @param keyObject used to extract the key object for each group.
+ * @param input used to extract the items in the iterator from an input row.
+ * @param serializer use to serialize the output of `func`.
+ */
+case class MapGroups(
+ func: (Any, Iterator[Any]) => TraversableOnce[Any],
+ keyObject: Expression,
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ groupingAttributes: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
+ def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
+
+/** Factory for constructing new `CoGroup` nodes. */
+object CoGroup {
+ def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder](
+ func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+ leftGroup: Seq[Attribute],
+ rightGroup: Seq[Attribute],
+ left: LogicalPlan,
+ right: LogicalPlan): CoGroup = {
+ CoGroup(
+ func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
+ encoderFor[Key].fromRowExpression,
+ encoderFor[Left].fromRowExpression,
+ encoderFor[Right].fromRowExpression,
+ encoderFor[Result].namedExpressions,
+ leftGroup,
+ rightGroup,
+ left,
+ right)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each grouping key and associated values from left and
+ * right children.
+ */
+case class CoGroup(
+ func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
+ keyObject: Expression,
+ leftObject: Expression,
+ rightObject: Expression,
+ serializer: Seq[NamedExpression],
+ leftGroup: Seq[Attribute],
+ rightGroup: Seq[Attribute],
+ left: LogicalPlan,
+ right: LogicalPlan) extends BinaryNode with ObjectOperator {
+ override def producedAttributes: AttributeSet = outputSet
+
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
new file mode 100644
index 0000000000..9177737560
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.NewInstance
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+case class OtherTuple(_1: Int, _2: Int)
+
+class EliminateSerializationSuite extends PlanTest {
+ private object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Serialization", FixedPoint(100),
+ EliminateSerialization) :: Nil
+ }
+
+ implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+ private val func = identity[Iterator[(Int, Int)]] _
+ private val func2 = identity[Iterator[OtherTuple]] _
+
+ def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = {
+ val newInstances = plan.flatMap(_.expressions.collect {
+ case n: NewInstance => n
+ })
+
+ if (newInstances.size != count) {
+ fail(
+ s"""
+ |Wrong number of object creations in plan: ${newInstances.size} != $count
+ |$plan
+ """.stripMargin)
+ }
+ }
+
+ test("back to back MapPartitions") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val plan =
+ MapPartitions(func,
+ MapPartitions(func, input))
+
+ val optimized = Optimize.execute(plan.analyze)
+ assertObjectCreations(1, optimized)
+ }
+
+ test("back to back with object change") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val plan =
+ MapPartitions(func,
+ MapPartitions(func2, input))
+
+ val optimized = Optimize.execute(plan.analyze)
+ assertObjectCreations(2, optimized)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 42f01e9359..9a9f7d111c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -336,12 +336,7 @@ class Dataset[T] private[sql](
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
new Dataset[U](
sqlContext,
- MapPartitions[T, U](
- func,
- resolvedTEncoder,
- encoderFor[U],
- encoderFor[U].schema.toAttributes,
- logicalPlan))
+ MapPartitions[T, U](func, logicalPlan))
}
/**
@@ -434,7 +429,7 @@ class Dataset[T] private[sql](
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
val inputPlan = logicalPlan
- val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
+ val withGroupingKey = AppendColumns(func, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index a819ddceb1..b3f8284364 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -115,8 +115,6 @@ class GroupedDataset[K, V] private[sql](
sqlContext,
MapGroups(
f,
- resolvedKEncoder,
- resolvedVEncoder,
groupingAttributes,
logicalPlan))
}
@@ -305,13 +303,11 @@ class GroupedDataset[K, V] private[sql](
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+ implicit val uEncoder = other.unresolvedVEncoder
new Dataset[R](
sqlContext,
CoGroup(
f,
- this.resolvedKEncoder,
- this.resolvedVEncoder,
- other.resolvedVEncoder,
this.groupingAttributes,
other.groupingAttributes,
this.logicalPlan,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 482130a18d..910519d0e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -309,16 +309,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
- case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
- execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
- case logical.AppendColumns(f, tEnc, uEnc, newCol, child) =>
- execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
- case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
- execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
- case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
- leftGroup, rightGroup, left, right) =>
- execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
- planLater(left), planLater(right)) :: Nil
+ case logical.MapPartitions(f, in, out, child) =>
+ execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+ case logical.AppendColumns(f, in, out, child) =>
+ execution.AppendColumns(f, in, out, planLater(child)) :: Nil
+ case logical.MapGroups(f, key, in, out, grouping, child) =>
+ execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
+ case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+ execution.CoGroup(
+ f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 95bef68323..92c9a56131 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -21,9 +21,7 @@ import org.apache.spark.{HashPartitioner, SparkEnv}
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
@@ -329,128 +327,3 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
protected override def doExecute(): RDD[InternalRow] = child.execute()
}
-
-/**
- * Applies the given function to each input row and encodes the result.
- */
-case class MapPartitions[T, U](
- func: Iterator[T] => Iterator[U],
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- output: Seq[Attribute],
- child: SparkPlan) extends UnaryNode {
- override def producedAttributes: AttributeSet = outputSet
-
- override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitionsInternal { iter =>
- val tBoundEncoder = tEncoder.bind(child.output)
- func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
- }
- }
-}
-
-/**
- * Applies the given function to each input row, appending the encoded result at the end of the row.
- */
-case class AppendColumns[T, U](
- func: T => U,
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- newColumns: Seq[Attribute],
- child: SparkPlan) extends UnaryNode {
- override def producedAttributes: AttributeSet = AttributeSet(newColumns)
-
- override def output: Seq[Attribute] = child.output ++ newColumns
-
- override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitionsInternal { iter =>
- val tBoundEncoder = tEncoder.bind(child.output)
- val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
- iter.map { row =>
- val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
- combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
- }
- }
- }
-}
-
-/**
- * Groups the input rows together and calls the function with each group and an iterator containing
- * all elements in the group. The result of this function is encoded and flattened before
- * being output.
- */
-case class MapGroups[K, T, U](
- func: (K, Iterator[T]) => TraversableOnce[U],
- kEncoder: ExpressionEncoder[K],
- tEncoder: ExpressionEncoder[T],
- uEncoder: ExpressionEncoder[U],
- groupingAttributes: Seq[Attribute],
- output: Seq[Attribute],
- child: SparkPlan) extends UnaryNode {
- override def producedAttributes: AttributeSet = outputSet
-
- override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(groupingAttributes) :: Nil
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- Seq(groupingAttributes.map(SortOrder(_, Ascending)))
-
- override protected def doExecute(): RDD[InternalRow] = {
- child.execute().mapPartitionsInternal { iter =>
- val grouped = GroupedIterator(iter, groupingAttributes, child.output)
- val groupKeyEncoder = kEncoder.bind(groupingAttributes)
- val groupDataEncoder = tEncoder.bind(child.output)
-
- grouped.flatMap { case (key, rowIter) =>
- val result = func(
- groupKeyEncoder.fromRow(key),
- rowIter.map(groupDataEncoder.fromRow))
- result.map(uEncoder.toRow)
- }
- }
- }
-}
-
-/**
- * Co-groups the data from left and right children, and calls the function with each group and 2
- * iterators containing all elements in the group from left and right side.
- * The result of this function is encoded and flattened before being output.
- */
-case class CoGroup[Key, Left, Right, Result](
- func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
- keyEnc: ExpressionEncoder[Key],
- leftEnc: ExpressionEncoder[Left],
- rightEnc: ExpressionEncoder[Right],
- resultEnc: ExpressionEncoder[Result],
- output: Seq[Attribute],
- leftGroup: Seq[Attribute],
- rightGroup: Seq[Attribute],
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
- override def producedAttributes: AttributeSet = outputSet
-
- override def requiredChildDistribution: Seq[Distribution] =
- ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
-
- override protected def doExecute(): RDD[InternalRow] = {
- left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
- val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
- val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val boundKeyEnc = keyEnc.bind(leftGroup)
- val boundLeftEnc = leftEnc.bind(left.output)
- val boundRightEnc = rightEnc.bind(right.output)
-
- new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
- case (key, leftResult, rightResult) =>
- val result = func(
- boundKeyEnc.fromRow(key),
- leftResult.map(boundLeftEnc.fromRow),
- rightResult.map(boundRightEnc.fromRow))
- result.map(resultEnc.toRow)
- }
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
new file mode 100644
index 0000000000..2acca1743c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.ObjectType
+
+/**
+ * Helper functions for physical operators that work with user defined objects.
+ */
+trait ObjectOperator extends SparkPlan {
+ def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = {
+ val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema)
+ (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType)
+ }
+
+ def generateToRow(serializer: Seq[Expression]): Any => InternalRow = {
+ val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) {
+ GenerateSafeProjection.generate(serializer)
+ } else {
+ GenerateUnsafeProjection.generate(serializer)
+ }
+ val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head
+ val outputRow = new SpecificMutableRow(inputType :: Nil)
+ (o: Any) => {
+ outputRow(0) = o
+ outputProjection(outputRow)
+ }
+ }
+}
+
+/**
+ * Applies the given function to each input row and encodes the result.
+ */
+case class MapPartitions(
+ func: Iterator[Any] => Iterator[Any],
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val getObject = generateToObject(input, child.output)
+ val outputObject = generateToRow(serializer)
+ func(iter.map(getObject)).map(outputObject)
+ }
+ }
+}
+
+/**
+ * Applies the given function to each input row, appending the encoded result at the end of the row.
+ */
+case class AppendColumns(
+ func: Any => Any,
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator {
+
+ override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)
+
+ private def newColumnSchema = serializer.map(_.toAttribute).toStructType
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val getObject = generateToObject(input, child.output)
+ val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
+ val outputObject = generateToRow(serializer)
+
+ iter.map { row =>
+ val newColumns = outputObject(func(getObject(row)))
+
+ // This operates on the assumption that we always serialize the result...
+ combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+ }
+ }
+ }
+}
+
+/**
+ * Groups the input rows together and calls the function with each group and an iterator containing
+ * all elements in the group. The result of this function is encoded and flattened before
+ * being output.
+ */
+case class MapGroups(
+ func: (Any, Iterator[Any]) => TraversableOnce[Any],
+ keyObject: Expression,
+ input: Expression,
+ serializer: Seq[NamedExpression],
+ groupingAttributes: Seq[Attribute],
+ child: SparkPlan) extends UnaryNode with ObjectOperator {
+
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+
+ val getKey = generateToObject(keyObject, groupingAttributes)
+ val getValue = generateToObject(input, child.output)
+ val outputObject = generateToRow(serializer)
+
+ grouped.flatMap { case (key, rowIter) =>
+ val result = func(
+ getKey(key),
+ rowIter.map(getValue))
+ result.map(outputObject)
+ }
+ }
+ }
+}
+
+/**
+ * Co-groups the data from left and right children, and calls the function with each group and 2
+ * iterators containing all elements in the group from left and right side.
+ * The result of this function is encoded and flattened before being output.
+ */
+case class CoGroup(
+ func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
+ keyObject: Expression,
+ leftObject: Expression,
+ rightObject: Expression,
+ serializer: Seq[NamedExpression],
+ leftGroup: Seq[Attribute],
+ rightGroup: Seq[Attribute],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode with ObjectOperator {
+
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
+ val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
+ val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
+
+ val getKey = generateToObject(keyObject, leftGroup)
+ val getLeft = generateToObject(leftObject, left.output)
+ val getRight = generateToObject(rightObject, right.output)
+ val outputObject = generateToRow(serializer)
+
+ new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
+ case (key, leftResult, rightResult) =>
+ val result = func(
+ getKey(key),
+ leftResult.map(getLeft),
+ rightResult.map(getRight))
+ result.map(outputObject)
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d7b86e3811..b69bb21db5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+case class OtherTuple(_1: String, _2: Int)
+
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -111,6 +113,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 2), ("b", 3), ("c", 4))
}
+ test("map with type change") {
+ val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+
+ checkAnswer(
+ ds.map(identity[(String, Int)])
+ .as[OtherTuple]
+ .map(identity[OtherTuple]),
+ OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
+ }
+
test("map and group by with class data") {
// We inject a group by here to make sure this test case is future proof
// when we implement better pipelining and local execution mode.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index fac26bd0c0..ce12f788b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -192,10 +192,10 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
- case _: MapPartitions[_, _] => return
- case _: MapGroups[_, _, _] => return
- case _: AppendColumns[_, _] => return
- case _: CoGroup[_, _, _, _] => return
+ case _: MapPartitions => return
+ case _: MapGroups => return
+ case _: AppendColumns => return
+ case _: CoGroup => return
case _: LogicalRelation => return
}.transformAllExpressions {
case a: ImperativeAggregate => return