about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 16
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 31
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala | 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 13
7 files changed, 60 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index d629172a74..7abeb03296 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override lazy val resolved = false
override def newInstance = this
+ override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
@@ -95,6 +96,7 @@ case class Star(
override lazy val resolved = false
override def newInstance = this
+ override def withNullability(newNullability: Boolean) = this
override def withQualifiers(newQualifiers: Seq[String]) = this
def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 655d4a08fe..9ce1f01056 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
type EvaluatedType = Any
- def nullable = baseReference.nullable
- def dataType = baseReference.dataType
- def exprId = baseReference.exprId
- def qualifiers = baseReference.qualifiers
- def name = baseReference.name
+ override def nullable = baseReference.nullable
+ override def dataType = baseReference.dataType
+ override def exprId = baseReference.exprId
+ override def qualifiers = baseReference.qualifiers
+ override def name = baseReference.name
- def newInstance = BoundReference(ordinal, baseReference.newInstance)
- def withQualifiers(newQualifiers: Seq[String]) =
+ override def newInstance = BoundReference(ordinal, baseReference.newInstance)
+ override def withNullability(newNullability: Boolean) =
+ BoundReference(ordinal, baseReference.withNullability(newNullability))
+ override def withQualifiers(newQualifiers: Seq[String]) =
BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
override def toString = s"$baseReference:$ordinal"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 66ae22e95b..934bad8c27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression {
abstract class Attribute extends NamedExpression {
self: Product =>
+ def withNullability(newNullability: Boolean): Attribute
def withQualifiers(newQualifiers: Seq[String]): Attribute
def toAttribute = this
@@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
*/
- def withNullability(newNullability: Boolean) = {
+ override def withNullability(newNullability: Boolean) = {
if (nullable == newNullability) {
this
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3e0639867b..b51a02d5ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -46,10 +46,16 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {
- protected def generatorOutput: Seq[Attribute] =
- alias
+ protected def generatorOutput: Seq[Attribute] = {
+ val output = alias
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
.getOrElse(generator.output)
+ if (join && outer) {
+ output.map(_.withNullability(true))
+ } else {
+ output
+ }
+ }
override def output =
if (join) child.output ++ generatorOutput else generatorOutput
@@ -81,11 +87,20 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {
override def references = condition.map(_.references).getOrElse(Set.empty)
- override def output = joinType match {
- case LeftSemi =>
- left.output
- case _ =>
- left.output ++ right.output
+
+ override def output = {
+ joinType match {
+ case LeftSemi =>
+ left.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case _ =>
+ left.output ++ right.output
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index d85d2d7844..c1ced8bfa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -83,8 +83,8 @@ case class Aggregate(
case a: AggregateExpression =>
ComputedAggregate(
a,
- BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
- AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
+ BindReferences.bindReference(a, childOutput),
+ AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
}
}.toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index da1e08be59..47b3d00262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection}
+import org.apache.spark.sql.catalyst.expressions._
/**
* :: DeveloperApi ::
@@ -39,8 +39,16 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
+ protected def generatorOutput: Seq[Attribute] = {
+ if (join && outer) {
+ generator.output.map(_.withNullability(true))
+ } else {
+ generator.output
+ }
+ }
+
override def output =
- if (join) child.output ++ generator.output else generator.output
+ if (join) child.output ++ generatorOutput else generatorOutput
override def execute() = {
if (join) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 84bdde38b7..4797cd7adb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -271,7 +271,18 @@ case class BroadcastNestedLoopJoin(
override def otherCopyArgs = sqlContext :: Nil
- def output = left.output ++ right.output
+ override def output = {
+ joinType match {
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case _ =>
+ left.output ++ right.output
+ }
+ }
/** The Streamed Relation */
def left = streamed