author    Liang-Chi Hsieh <simonh@tw.ibm.com>    2016-05-12 10:11:12 -0700
committer Yin Huai <yhuai@databricks.com>        2016-05-12 10:11:12 -0700
commit    470de743ecf3617babd86f50ab203e85aa975d69 (patch)
tree      58a87637c705c3c0a8430a49607ec4c603e6dc22
parent    5bb62b893bf13973de63ab28571e05501b84bfef (diff)
[SPARK-15094][SPARK-14803][SQL] Remove extra Project added in EliminateSerialization
## What changes were proposed in this pull request?

The `Optimizer` eliminates adjacent pairs of `DeserializeToObject` and `SerializeFromObject` and inserts an extra `Project` in their place. However, when the `DeserializeToObject`'s output type is an `ObjectType` whose class cannot be handled by an unsafe projection, that projection fails at runtime. To fix this, we remove the extra `Project` again and instead replace the output attribute of `DeserializeToObject` in a separate rule.

## How was this patch tested?

`DatasetSuite`.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #12926 from viirya/fix-eliminate-serialization-projection.
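For context, a minimal sketch of the failure this patch addresses, adapted from the `DatasetSuite` test added below (a standalone repro is assumed, with `import sqlContext.implicits._` in scope for `toDS`):

```scala
// Sketch adapted from the new DatasetSuite test; assumes Dataset implicits
// (e.g. `import sqlContext.implicits._`) are in scope for `toDS`.
case class Generic[T](id: T, value: Double)

val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
val ds2 = ds.map(g => Generic(g.id, g.value))

// `.rdd` plans a DeserializeToObject on top of the dataset's SerializeFromObject.
// Before this patch, eliminating that pair left behind a Project over an
// ObjectType attribute that the unsafe projection could not handle, so this
// count failed at runtime.
assert(ds2.rdd.map(_.id).count() == 2)
```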
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  60
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala  12
2 files changed, 62 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 350b60134e..928ba213b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -102,7 +102,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
SimplifyCasts,
SimplifyCaseConversionExpressions,
RewriteCorrelatedScalarSubquery,
- EliminateSerialization) ::
+ EliminateSerialization,
+ RemoveAliasOnlyProject) ::
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
@@ -156,6 +157,49 @@ object SamplePushDown extends Rule[LogicalPlan] {
}
/**
+ * Removes a Project node if it merely aliases the output of its child node.
+ * This rule exists mainly to remove the extra Project added by the
+ * EliminateSerialization rule, but it can also benefit other operators.
+ */
+object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
+ // Checks whether the projectList of the Project node consists of aliases that keep
+ // the same attribute names and ordering as the output of its child node.
+ private def isAliasOnly(
+ projectList: Seq[NamedExpression],
+ childOutput: Seq[Attribute]): Boolean = {
+ if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) {
+ return false
+ } else {
+ projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) =>
+ a.child match {
+ case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true
+ case _ => false
+ }
+ }
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val aliasOnlyProject = plan.find { p =>
+ p match {
+ case Project(pList, child) if isAliasOnly(pList, child.output) => true
+ case _ => false
+ }
+ }
+
+ aliasOnlyProject.map { case p: Project =>
+ val aliases = p.projectList.map(_.asInstanceOf[Alias])
+ val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child)))
+ plan.transformAllExpressions {
+ case a: Attribute if attrMap.contains(a) => attrMap(a)
+ }.transform {
+ case op: Project if op.eq(p) => op.child
+ }
+ }.getOrElse(plan)
+ }
+}
+
+/**
* Removes cases where we are unnecessarily going between the object and serialized (InternalRow)
* representation of a data item, for example back-to-back map operations.
*/
@@ -163,15 +207,11 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
- // A workaround for SPARK-14803. Remove this after it is fixed.
- if (d.outputObjectType.isInstanceOf[ObjectType] &&
- d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
- s.child
- } else {
- // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
- val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
- Project(objAttr :: Nil, s.child)
- }
+ // Adds an extra Project here to preserve the output expr id of `DeserializeToObject`.
+ // It will be removed later by the RemoveAliasOnlyProject rule.
+ val objAttr =
+ Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
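To make the isAliasOnly check in the new RemoveAliasOnlyProject rule above easier to follow, here is a self-contained toy model. The `Attr`/`AttrRef`/`Alias` types are simplified, hypothetical stand-ins for Catalyst's `Attribute` and `Alias`, and plain equality stands in for `semanticEquals`:

```scala
// Toy model of the isAliasOnly check (hypothetical simplified types; the
// real rule works on Catalyst NamedExpression/Alias/Attribute and compares
// attributes with semanticEquals rather than plain equality).
object AliasOnlyDemo extends App {
  final case class Attr(name: String, exprId: Int)
  sealed trait Expr
  final case class AttrRef(attr: Attr) extends Expr
  final case class Alias(child: Expr, name: String, exprId: Int) extends Expr

  // Alias-only means: same number of entries, every entry is an Alias of the
  // corresponding child attribute, and the alias renames nothing.
  def isAliasOnly(projectList: Seq[Expr], childOutput: Seq[Attr]): Boolean =
    projectList.length == childOutput.length &&
      projectList.zip(childOutput).forall {
        case (Alias(AttrRef(attr), name, _), out) => name == attr.name && attr == out
        case _ => false
      }

  // Project [obj#1 AS obj#0] over child output [obj#1] is alias-only...
  assert(isAliasOnly(Seq(Alias(AttrRef(Attr("obj", 1)), "obj", 0)), Seq(Attr("obj", 1))))
  // ...but a genuine rename is not.
  assert(!isAliasOnly(Seq(Alias(AttrRef(Attr("obj", 1)), "renamed", 0)), Seq(Attr("obj", 1))))
}
```

Note that the rule only drops the Project after rewriting every reference to the alias's expr id (obj#0 above) to point at the underlying child attribute (obj#1), which is why `apply` runs `transformAllExpressions` before removing the node.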
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0784041f34..3b9feae4a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -661,6 +661,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}
+ test("dataset.rdd with generic case class") {
+ val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
+ val ds2 = ds.map(g => Generic(g.id, g.value))
+ assert(ds.rdd.map(r => r.id).count === 2)
+ assert(ds2.rdd.map(r => r.id).count === 2)
+
+ val ds3 = ds.map(g => new java.lang.Long(g.id))
+ assert(ds3.rdd.map(r => r).count === 2)
+ }
+
test("runtime null check for RowEncoder") {
val schema = new StructType().add("i", IntegerType, nullable = false)
val df = sqlContext.range(10).map(l => {
@@ -694,6 +704,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}
}
+case class Generic[T](id: T, value: Double)
+
case class OtherTuple(_1: String, _2: Int)
case class TupleClass(data: (Int, String))
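Finally, a hedged usage-level sketch of the "back-to-back map operations" case that the EliminateSerialization doc comment above refers to (assumes `sqlContext.implicits._` is in scope, matching the test suite's setup):

```scala
// Chained typed maps: without EliminateSerialization, the first map would
// serialize its result to InternalRow and the second would immediately
// deserialize it back; the rule removes that pair so the two lambdas
// compose object-to-object.
import sqlContext.implicits._

val chained = Seq(1, 2, 3).toDS.map(_ + 1).map(_ * 2)
assert(chained.collect().sorted.sameElements(Array(4, 6, 8)))
```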