-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 60
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                     | 12
2 files changed, 62 insertions(+), 10 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 350b60134e..928ba213b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -102,7 +102,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
       RewriteCorrelatedScalarSubquery,
-      EliminateSerialization) ::
+      EliminateSerialization,
+      RemoveAliasOnlyProject) ::
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
@@ -156,6 +157,49 @@ object SamplePushDown extends Rule[LogicalPlan] {
 }
 
 /**
+ * Removes a Project node that only re-aliases the output of its child.
+ * It mainly removes the extra Project added by the EliminateSerialization rule,
+ * but can also benefit other operators.
+ */
+object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
+  // Checks whether the projectList of the Project node has the same attribute names
+  // and ordering as the output of its child node.
+  private def isAliasOnly(
+      projectList: Seq[NamedExpression],
+      childOutput: Seq[Attribute]): Boolean = {
+    if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) {
+      false
+    } else {
+      projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) =>
+        a.child match {
+          case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true
+          case _ => false
+        }
+      }
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val aliasOnlyProject = plan.find { p =>
+      p match {
+        case Project(pList, child) if isAliasOnly(pList, child.output) => true
+        case _ => false
+      }
+    }
+
+    aliasOnlyProject.map { case p: Project =>
+      val aliases = p.projectList.map(_.asInstanceOf[Alias])
+      val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child)))
+      plan.transformAllExpressions {
+        case a: Attribute if attrMap.contains(a) => attrMap(a)
+      }.transform {
+        case op: Project if op.eq(p) => op.child
+      }
+    }.getOrElse(plan)
+  }
+}
+
+/**
  * Removes cases where we are unnecessarily going between the object and serialized (InternalRow)
  * representation of data item. For example back to back map operations.
  */
@@ -163,15 +207,11 @@ object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
-      // A workaround for SPARK-14803. Remove this after it is fixed.
-      if (d.outputObjectType.isInstanceOf[ObjectType] &&
-          d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
-        s.child
-      } else {
-        // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
-        val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
-        Project(objAttr :: Nil, s.child)
-      }
+      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+      // It will be removed later by the RemoveAliasOnlyProject rule.
+      val objAttr =
+        Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+      Project(objAttr :: Nil, s.child)
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
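Note: the rewrite that RemoveAliasOnlyProject performs is easiest to see on a toy plan. The sketch below is a minimal, self-contained Scala model of the idea, not Catalyst code: Attr, Relation, Project, and both helper functions are hypothetical stand-ins, and the real rule must additionally redirect every reference to the removed aliases elsewhere in the plan (the attrMap / transformAllExpressions step in the patch above).

object RemoveAliasOnlyProjectSketch extends App {
  // Toy stand-ins, NOT Catalyst classes. An Attr is an output attribute
  // identified by its name plus a unique expression id.
  case class Attr(name: String, id: Long)

  sealed trait Plan { def output: Seq[Attr] }
  case class Relation(output: Seq[Attr]) extends Plan
  // Each alias pairs a fresh output attribute with the child attribute it
  // renames, mirroring Alias(child, name)(exprId) in the patch above.
  case class Project(aliases: Seq[(Attr, Attr)], child: Plan) extends Plan {
    def output: Seq[Attr] = aliases.map { case (out, _) => out }
  }

  // A Project is "alias only" when it renames every child attribute to the
  // same name, in the same order: it allocates new expression ids but
  // changes nothing observable about the rows.
  def isAliasOnly(p: Project): Boolean =
    p.aliases.map(_._2) == p.child.output &&
      p.aliases.forall { case (out, in) => out.name == in.name }

  // Dropping such a Project simply exposes its child; the real rule also
  // rewrites references to the removed aliases across the whole plan.
  def removeAliasOnlyProject(plan: Plan): Plan = plan match {
    case p: Project if isAliasOnly(p) => p.child
    case other => other
  }

  val rel  = Relation(Seq(Attr("obj", id = 1)))
  val proj = Project(Seq((Attr("obj", id = 2), Attr("obj", id = 1))), rel)
  assert(removeAliasOnlyProject(proj) == rel) // the alias-only Project is gone
  println(removeAliasOnlyProject(proj))
}

Because the real rule handles one matched Project per invocation (via plan.find), it relies on the enclosing optimizer batch running to a fixed point, as the neighbouring batches in the first hunk do, to clean up multiple alias-only Projects.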
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0784041f34..3b9feae4a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -661,6 +661,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
   }
 
+  test("dataset.rdd with generic case class") {
+    val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
+    val ds2 = ds.map(g => Generic(g.id, g.value))
+    assert(ds.rdd.map(r => r.id).count === 2)
+    assert(ds2.rdd.map(r => r.id).count === 2)
+
+    val ds3 = ds.map(g => new java.lang.Long(g.id))
+    assert(ds3.rdd.map(r => r).count === 2)
+  }
+
   test("runtime null check for RowEncoder") {
     val schema = new StructType().add("i", IntegerType, nullable = false)
     val df = sqlContext.range(10).map(l => {
@@ -694,6 +704,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 }
 
+case class Generic[T](id: T, value: Double)
+
 case class OtherTuple(_1: String, _2: Int)
 
 case class TupleClass(data: (Int, String))
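Note: to watch the two rules cooperate interactively, a spark-shell session along the following lines should work. This is a sketch against the sqlContext-era APIs used in this suite; the exact shape of the printed optimized plan is indicative only.

import sqlContext.implicits._

case class Generic[T](id: T, value: Double)

val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS()

// Back-to-back typed maps create a DeserializeToObject/SerializeFromObject
// pair; EliminateSerialization now rewrites it into an alias-only Project,
// which RemoveAliasOnlyProject then drops from the optimized plan.
val ds2 = ds.map(g => Generic(g.id, g.value)).map(g => g)
println(ds2.queryExecution.optimizedPlan.treeString)

// The .rdd round trip exercised by the new test keeps working as well:
assert(ds2.rdd.map(_.id).count() == 2)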