diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2014-07-05 11:51:48 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-07-05 11:51:48 -0700 |
commit | 9d5ecf8205b924dc8a3c13fed68beb78cc5c7553 (patch) | |
tree | f6ec13b96cfa3ed5391a1538e2c8eb97713bea98 /sql | |
parent | 3da8df939ec63064692ba64d9188aeea908b305c (diff) | |
download | spark-9d5ecf8205b924dc8a3c13fed68beb78cc5c7553.tar.gz spark-9d5ecf8205b924dc8a3c13fed68beb78cc5c7553.tar.bz2 spark-9d5ecf8205b924dc8a3c13fed68beb78cc5c7553.zip |
[SPARK-2327] [SQL] Fix nullabilities of Join/Generate/Aggregate.
Fix nullabilities of `Join`/`Generate`/`Aggregate` because:
- Output attributes of opposite side of `OuterJoin` should be nullable.
- Output attributes of generater side of `Generate` should be nullable if `join` is `true` and `outer` is `true`.
- `AttributeReference` of `computedAggregates` of `Aggregate` should be the same as `aggregateExpression`'s.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #1266 from ueshin/issues/SPARK-2327 and squashes the following commits:
3ace83a [Takuya UESHIN] Add withNullability to Attribute and use it to change nullabilities.
df1ae53 [Takuya UESHIN] Modify nullabilize to leave attribute if not resolved.
799ce56 [Takuya UESHIN] Add nullabilization to Generate of SparkPlan.
a0fc9bc [Takuya UESHIN] Fix scalastyle errors.
0e31e37 [Takuya UESHIN] Fix Aggregate resultAttribute nullabilities.
09532ec [Takuya UESHIN] Fix Generate output nullabilities.
f20f196 [Takuya UESHIN] Fix Join output nullabilities.
Diffstat (limited to 'sql')
7 files changed, 60 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d629172a74..7abeb03296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override lazy val resolved = false override def newInstance = this + override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this // Unresolved attributes are transient at compile time and don't get evaluated during execution. @@ -95,6 +96,7 @@ case class Star( override lazy val resolved = false override def newInstance = this + override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this def expand(input: Seq[Attribute]): Seq[NamedExpression] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 655d4a08fe..9ce1f01056 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute) type EvaluatedType = Any - def nullable = baseReference.nullable - def dataType = baseReference.dataType - def exprId = baseReference.exprId - def qualifiers = baseReference.qualifiers - def name = baseReference.name + override def nullable = baseReference.nullable + override def dataType = baseReference.dataType + override def exprId = baseReference.exprId + override def qualifiers = baseReference.qualifiers + override def name = baseReference.name - def newInstance = BoundReference(ordinal, baseReference.newInstance) - def withQualifiers(newQualifiers: Seq[String]) = + override def newInstance = BoundReference(ordinal, baseReference.newInstance) + override def withNullability(newNullability: Boolean) = + BoundReference(ordinal, baseReference.withNullability(newNullability)) + override def withQualifiers(newQualifiers: Seq[String]) = BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) override def toString = s"$baseReference:$ordinal" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 66ae22e95b..934bad8c27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => + def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def toAttribute = this @@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with changed nullability. */ - def withNullability(newNullability: Boolean) = { + override def withNullability(newNullability: Boolean) = { if (nullable == newNullability) { this } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bac5a72464..0728fa73fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { @@ -46,10 +46,16 @@ case class Generate( child: LogicalPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = - alias + protected def generatorOutput: Seq[Attribute] = { + val output = alias .map(a => generator.output.map(_.withQualifiers(a :: Nil))) .getOrElse(generator.output) + if (join && outer) { + output.map(_.withNullability(true)) + } else { + output + } + } override def output = if (join) child.output ++ generatorOutput else generatorOutput @@ -81,11 +87,20 @@ case class Join( condition: Option[Expression]) extends BinaryNode { override def references = condition.map(_.references).getOrElse(Set.empty) - override def output = joinType match { - case LeftSemi => - left.output - case _ => - left.output ++ right.output + + override def output = { + joinType match { + case LeftSemi => + left.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index d85d2d7844..c1ced8bfa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -83,8 +83,8 @@ case class Aggregate( case a: AggregateExpression => ComputedAggregate( a, - BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression], - AttributeReference(s"aggResult:$a", a.dataType, nullable = true)()) + BindReferences.bindReference(a, childOutput), + AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) } }.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index da1e08be59..47b3d00262 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection} +import org.apache.spark.sql.catalyst.expressions._ /** * :: DeveloperApi :: @@ -39,8 +39,16 @@ case class Generate( child: SparkPlan) extends UnaryNode { + protected def generatorOutput: Seq[Attribute] = { + if (join && outer) { + generator.output.map(_.withNullability(true)) + } else { + generator.output + } + } + override def output = - if (join) child.output ++ generator.output else generator.output + if (join) child.output ++ generatorOutput else generatorOutput override def execute() = { if (join) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 32c5f26fe8..7d1f11caae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -319,7 +319,18 @@ case class BroadcastNestedLoopJoin( override def otherCopyArgs = sqlContext :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } /** The Streamed Relation */ def left = streamed |