commit    4af647c77ded6a0d3087ceafb2e30e01d97e7a06 (patch)
author    Davies Liu <davies@databricks.com>    2015-12-18 10:09:17 -0800
committer Davies Liu <davies.liu@gmail.com>    2015-12-18 10:09:17 -0800
tree      893ad3f9d8de5b34a9ad46eba4862f898c68f044
parent    ee444fe4b8c9f382524e1fa346c67ba6da8104d8 (diff)
[SPARK-12054] [SQL] Consider nullability of expression in codegen
This simplifies the generated code for expressions that are not nullable, and fixes a number of nullability-related bugs along the way.

Author: Davies Liu <davies@databricks.com>

Closes #10333 from davies/skip_nullable.
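For context, the pattern this patch applies throughout codegen looks roughly like the following self-contained Scala sketch (illustrative only: ExprCode and genBoundReference are simplified stand-ins for Spark's GeneratedExpressionCode and BoundReference.genCode, and the row accessor is made up):

// A mutable holder for generated-code fragments: the name of the boolean
// null flag and the name of the value variable.
case class ExprCode(var isNull: String, value: String)

def genBoundReference(
    ordinal: Int,
    javaType: String,
    defaultValue: String,
    nullable: Boolean,
    ev: ExprCode,
    inputRow: String = "i"): String = {
  // Hypothetical accessor; the real codegen dispatches on the data type.
  val value = s"$inputRow.get($ordinal)"
  if (nullable) {
    // Nullable input: emit an isNull flag and guard the read with it.
    s"""
       |boolean ${ev.isNull} = $inputRow.isNullAt($ordinal);
       |$javaType ${ev.value} = ${ev.isNull} ? $defaultValue : ($value);
       |""".stripMargin
  } else {
    // Non-nullable input: no flag, no branch. Folding isNull to the
    // constant "false" lets downstream generators drop their null checks.
    ev.isNull = "false"
    s"""
       |$javaType ${ev.value} = $value;
       |""".stripMargin
  }
}

Folding ev.isNull to the string literal "false" is what the rest of the patch keys on; for example, GenerateUnsafeProjection below tests input.isNull == "false" to elide the null branch entirely.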
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 28
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 95
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala | 65
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 21
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 9
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala | 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 80
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala | 2
27 files changed, 261 insertions(+), 226 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ff1f28ddbb..7293d5d447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
- s"""
- boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
- $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
- """
+ if (nullable) {
+ s"""
+ boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
+ $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ $javaType ${ev.value} = $value;
+ """
+ }
}
}
@@ -92,7 +99,7 @@ object BindReferences extends Logging {
sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
}
} else {
- BoundReference(ordinal, a.dataType, a.nullable)
+ BoundReference(ordinal, a.dataType, input(ordinal).nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index cb60d5958d..b18f49f320 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -87,18 +87,22 @@ object Cast {
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
private def forceNullable(from: DataType, to: DataType) = (from, to) match {
- case (StringType, _: NumericType) => true
- case (StringType, TimestampType) => true
- case (DoubleType, TimestampType) => true
- case (FloatType, TimestampType) => true
- case (StringType, DateType) => true
- case (_: NumericType, DateType) => true
- case (BooleanType, DateType) => true
- case (DateType, _: NumericType) => true
- case (DateType, BooleanType) => true
- case (DoubleType, _: DecimalType) => true
- case (FloatType, _: DecimalType) => true
- case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
+ case (NullType, _) => true
+ case (_, _) if from == to => false
+
+ case (StringType, BinaryType) => false
+ case (StringType, _) => true
+ case (_, StringType) => false
+
+ case (FloatType | DoubleType, TimestampType) => true
+ case (TimestampType, DateType) => false
+ case (_, DateType) => true
+ case (DateType, TimestampType) => false
+ case (DateType, _) => true
+ case (_, CalendarIntervalType) => true
+
+ case (_, _: DecimalType) => true // overflow
+ case (_: FractionalType, _: IntegralType) => true // NaN, infinity
case _ => false
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6d807c9ecf..6a9c12127d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -340,14 +340,21 @@ abstract class UnaryExpression extends Expression {
ev: GeneratedExpressionCode,
f: String => String): String = {
val eval = child.gen(ctx)
- val resultCode = f(eval.value)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- $resultCode
- }
- """
+ if (nullable) {
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${eval.isNull}) {
+ ${f(eval.value)}
+ }
+ """
+ } else {
+ ev.isNull = "false"
+ eval.code + s"""
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${f(eval.value)}
+ """
+ }
}
}
@@ -424,19 +431,30 @@ abstract class BinaryExpression extends Expression {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.value, eval2.value)
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${eval2.code}
- if (!${eval2.isNull}) {
- $resultCode
- } else {
- ${ev.isNull} = true;
+ if (nullable) {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${eval2.code}
+ if (!${eval2.isNull}) {
+ $resultCode
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+
+ } else {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ ${eval2.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $resultCode
+ """
+ }
}
}
@@ -548,20 +566,31 @@ abstract class TernaryExpression extends Expression {
f: (String, String, String) => String): String = {
val evals = children.map(_.gen(ctx))
val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
- s"""
- ${evals(0).code}
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${evals(0).isNull}) {
- ${evals(1).code}
- if (!${evals(1).isNull}) {
- ${evals(2).code}
- if (!${evals(2).isNull}) {
- ${ev.isNull} = false; // resultCode could change nullability
- $resultCode
+ if (nullable) {
+ s"""
+ ${evals(0).code}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${evals(0).isNull}) {
+ ${evals(1).code}
+ if (!${evals(1).isNull}) {
+ ${evals(2).code}
+ if (!${evals(2).isNull}) {
+ ${ev.isNull} = false; // resultCode could change nullability
+ $resultCode
+ }
}
}
- }
- """
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ ${evals(0).code}
+ ${evals(1).code}
+ ${evals(2).code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $resultCode
+ """
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index d07d4c338c..30f602227b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
override def children: Seq[Expression] = Seq(child)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def dataType: DataType = DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 00d7436b71..d25f3335ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -42,7 +41,7 @@ case class Corr(
override def children: Seq[Expression] = Seq(left, right)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def dataType: DataType = DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 441f52ab5c..663c69e799 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
- private lazy val count = AttributeReference("count", LongType)()
+ private lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
@@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
/* count = */ Literal(0L)
)
- override lazy val updateExpressions = Seq(
- /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
- )
+ override lazy val updateExpressions = {
+ val nullableChildren = children.filter(_.nullable)
+ if (nullableChildren.isEmpty) {
+ Seq(
+ /* count = */ count + 1L
+ )
+ } else {
+ Seq(
+ /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
+ )
+ }
+ }
override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
- override lazy val evaluateExpression = Cast(count, LongType)
+ override lazy val evaluateExpression = count
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index cfb042e0aa..08a67ea3df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
- // TODO: Remove this line once we remove the NullType from inputTypes.
- case NullType => IntegerType
case _ => child.dataType
}
@@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
/* sum = */ Literal.create(null, sumDataType)
)
- override lazy val updateExpressions: Seq[Expression] = Seq(
- /* sum = */
- Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
- )
+ override lazy val updateExpressions: Seq[Expression] = {
+ if (child.nullable) {
+ Seq(
+ /* sum = */
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
+ )
+ } else {
+ Seq(
+ /* sum = */
+ Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
+ )
+ }
+ }
override lazy val mergeExpressions: Seq[Expression] = {
- val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
- Coalesce(Seq(add, sum.left))
+ Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
)
}
- override lazy val evaluateExpression: Expression = Cast(sum, resultType)
+ override lazy val evaluateExpression: Expression = sum
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 26fb143d1e..80c5e41baa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -32,14 +32,23 @@ trait CodegenFallback extends Expression {
ctx.references += this
val objectTerm = ctx.freshName("obj")
- s"""
- /* expression: ${this.toCommentSafeString} */
- java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
- boolean ${ev.isNull} = $objectTerm == null;
- ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
- }
- """
+ if (nullable) {
+ s"""
+ /* expression: ${this.toCommentSafeString} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+ boolean ${ev.isNull} = $objectTerm == null;
+ ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ }
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ /* expression: ${this.toCommentSafeString} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+ ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ """
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 40189f0877..a6ec242589 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
- val isNull = s"isNull_$i"
- val value = s"value_$i"
- ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
- ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"this.$value = ${ctx.defaultValue(e.dataType)};")
- s"""
- ${evaluationCode.code}
- this.$isNull = ${evaluationCode.isNull};
- this.$value = ${evaluationCode.value};
- """
+ if (e.nullable) {
+ val isNull = s"isNull_$i"
+ val value = s"value_$i"
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$isNull = ${evaluationCode.isNull};
+ this.$value = ${evaluationCode.value};
+ """
+ } else {
+ val value = s"value_$i"
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$value = ${evaluationCode.value};
+ """
+ }
}
val updates = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
- if (e.dataType.isInstanceOf[DecimalType]) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- s"""
- if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, null)};
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
+ if (e.nullable) {
+ if (e.dataType.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ s"""
+ if (this.isNull_$i) {
+ ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
+ }
+ """
+ } else {
+ s"""
+ if (this.isNull_$i) {
+ mutableRow.setNullAt($i);
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
+ }
+ """
+ }
} else {
s"""
- if (this.isNull_$i) {
- mutableRow.setNullAt($i);
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
"""
}
+
}
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 68005afb21..c1defe12b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$rowWriter.write($index, ${input.value});"
}
- s"""
- ${input.code}
- if (${input.isNull}) {
- ${setNull.trim}
- } else {
+ if (input.isNull == "false") {
+ s"""
+ ${input.code}
${writeField.trim}
- }
- """
+ """
+ } else {
+ s"""
+ ${input.code}
+ if (${input.isNull}) {
+ ${setNull.trim}
+ } else {
+ ${writeField.trim}
+ }
+ """
+ }
}
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 58f6a7ec8a..c5ed173eeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -115,13 +115,19 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
- s"""
- if ($eval.isNullAt($ordinal)) {
- ${ev.isNull} = true;
- } else {
+ if (nullable) {
+ s"""
+ if ($eval.isNullAt($ordinal)) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ }
+ """
+ } else {
+ s"""
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
- }
- """
+ """
+ }
})
}
}
@@ -139,7 +145,6 @@ case class GetArrayStructFields(
containsNull: Boolean) extends UnaryExpression {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable || containsNull || field.nullable
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 03c39f8404..311540e335 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -340,6 +340,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
Seq(TypeCollection(StringType, DateType, TimestampType), StringType)
override def dataType: DataType = LongType
+ override def nullable: Boolean = true
private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
@@ -455,6 +456,7 @@ case class FromUnixTime(sec: Expression, format: Expression)
}
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
@@ -561,6 +563,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
+ override def nullable: Boolean = true
override def nullSafeEval(start: Any, dayOfW: Any): Any = {
val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String])
@@ -832,6 +835,7 @@ case class TruncDate(date: Expression, format: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
+ override def nullable: Boolean = true
override def prettyName: String = "trunc"
private lazy val truncLevel: Int =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 78f6631e46..c54bcdd774 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
override def dataType: DataType = DecimalType(precision, scale)
+ override def nullable: Boolean = true
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
protected override def nullSafeEval(input: Any): Any =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 4991b9cb54..72b323587c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,18 +17,19 @@
package org.apache.spark.sql.catalyst.expressions
-import java.io.{StringWriter, ByteArrayOutputStream}
+import java.io.{ByteArrayOutputStream, StringWriter}
+
+import scala.util.parsing.combinator.RegexParsers
import com.fasterxml.jackson.core._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType}
+import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
-import scala.util.parsing.combinator.RegexParsers
-
private[this] sealed trait PathInstruction
private[this] object PathInstruction {
private[expressions] case object Subscript extends PathInstruction
@@ -108,15 +109,17 @@ private[this] object SharedFactory {
case class GetJsonObject(json: Expression, path: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
- import SharedFactory._
+ import com.fasterxml.jackson.core.JsonToken._
+
import PathInstruction._
+ import SharedFactory._
import WriteStyle._
- import com.fasterxml.jackson.core.JsonToken._
override def left: Expression = json
override def right: Expression = path
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def prettyName: String = "get_json_object"
@transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 28f616fbb9..9c1a3294de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -75,6 +75,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
abstract class UnaryLogExpression(f: Double => Double, name: String)
extends UnaryMathExpression(f, name) {
+ override def nullable: Boolean = true
+
// values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
protected val yAsymptote: Double = 0.0
@@ -194,6 +196,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
NumberConverter.convert(
@@ -621,6 +624,8 @@ case class Logarithm(left: Expression, right: Expression)
this(EulerNumber(), child)
}
+ override def nullable: Boolean = true
+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val dLeft = input1.asInstanceOf[Double]
val dRight = input2.asInstanceOf[Double]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 0f6d02f2e0..5baab4f7e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -57,6 +57,7 @@ case class Sha2(left: Expression, right: Expression)
extends BinaryExpression with Serializable with ImplicitCastInputTypes {
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 8770c4b76c..50c8b9d598 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -924,6 +924,7 @@ case class FormatNumber(x: Expression, d: Expression)
override def left: Expression = x
override def right: Expression = d
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
// Associated with the pattern, for the last d value, and we will update the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 06252ac4fc..91f169e7ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -329,7 +329,7 @@ abstract class OffsetWindowFunction
*/
override def foldable: Boolean = input.foldable && (default == null || default.foldable)
- override def nullable: Boolean = input.nullable && (default == null || default.nullable)
+ override def nullable: Boolean = default == null || default.nullable
override lazy val frame = {
// This will be triggered by the Analyzer.
@@ -381,7 +381,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
self: Product =>
override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
override def dataType: DataType = IntegerType
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def supportsPartial: Boolean = false
override lazy val mergeExpressions =
throw new UnsupportedOperationException("Window Functions do not support merging.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5665fd7e5f..ec42b763f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -293,7 +293,14 @@ private[sql] object Expand {
Literal.create(bitmask, IntegerType)
})
}
- Expand(projections, child.output :+ gid, child)
+ val output = child.output.map { attr =>
+ if (groupByExprs.exists(_.semanticEquals(attr))) {
+ attr.withNullability(true)
+ } else {
+ attr
+ }
+ }
+ Expand(projections, output :+ gid, child)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a98e16c253..c99a4ac964 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from string") {
assert(cast("abcdef", StringType).nullable === false)
assert(cast("abcdef", BinaryType).nullable === false)
- assert(cast("abcdef", BooleanType).nullable === false)
+ assert(cast("abcdef", BooleanType).nullable === true)
assert(cast("abcdef", TimestampType).nullable === true)
assert(cast("abcdef", LongType).nullable === true)
assert(cast("abcdef", IntegerType).nullable === true)
@@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Seq(null, true, false))
}
@@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
@@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, InternalRow(null, true, false))
}
@@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType(Seq(
StructField("l", LongType, nullable = true)))))))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Row(
Seq(123, null, null),
Map("a" -> null, "b" -> true, "c" -> false),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9f1b19253e..9c1688b261 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index d741312314..965eaa9efe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -25,22 +25,20 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import com.fasterxml.jackson.core.JsonFactory
-import org.apache.commons.lang3.StringUtils
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
+import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
+import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution}
import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -455,7 +453,8 @@ class DataFrame private[sql](
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branch.
val joined = sqlContext.executePlan(
- Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join]
+ Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None))
+ .analyzed.asInstanceOf[Join]
val condition = usingColumns.map { col =>
catalyst.expressions.EqualTo(
@@ -473,15 +472,15 @@ class DataFrame private[sql](
usingColumns.map(col => withPlan(joined.right).resolve(col))
case FullOuter =>
usingColumns.map { col =>
- val leftCol = withPlan(joined.left).resolve(col)
- val rightCol = withPlan(joined.right).resolve(col)
+ val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true)
+ val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
Alias(Coalesce(Seq(leftCol, rightCol)), col)()
}
}
// The nullability of output of joined could be different than original column,
// so we can only compare them by exprId
- val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil)
- val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId))
+ val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references))
+ val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_))
withPlan {
Project(
resultCols,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 9852b6e7be..c941d673c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -440,16 +440,17 @@ private[execution] final class OffsetWindowFunctionFrame(
/** Create the projection. */
private[this] val projection = {
// Collect the expressions and bind them.
- val numInputAttributes = inputSchema.size
+ val inputAttrs = inputSchema.map(_.withNullability(true))
+ val numInputAttributes = inputAttrs.size
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map {
case e: OffsetWindowFunction =>
- val input = BindReferences.bindReference(e.input, inputSchema)
+ val input = BindReferences.bindReference(e.input, inputAttrs)
if (e.default == null || e.default.foldable && e.default.eval() == null) {
// Without default value.
input
} else {
// With default value.
- val default = BindReferences.bindReference(e.default, inputSchema).transform {
+ val default = BindReferences.bindReference(e.default, inputAttrs).transform {
// Shift the input reference to its default version.
case BoundReference(o, dataType, nullable) =>
BoundReference(o + numInputAttributes, dataType, nullable)
@@ -457,7 +458,7 @@ private[execution] final class OffsetWindowFunctionFrame(
org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil)
}
case e =>
- BindReferences.bindReference(e, inputSchema)
+ BindReferences.bindReference(e, inputAttrs)
}
// Create the projection.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 24a79f289a..e2dc13d66c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -232,7 +232,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
case class ExplainCommand(
logicalPlan: LogicalPlan,
override val output: Seq[Attribute] =
- Seq(AttributeReference("plan", StringType, nullable = false)()),
+ Seq(AttributeReference("plan", StringType, nullable = true)()),
extended: Boolean = false)
extends RunnableCommand {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index e7deeff13d..e759c011e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -42,7 +42,7 @@ case class DescribeCommand(
new MetadataBuilder().putString("comment", "name of the column").build())(),
AttributeReference("data_type", StringType, nullable = false,
new MetadataBuilder().putString("comment", "data type of the column").build())(),
- AttributeReference("comment", StringType, nullable = false,
+ AttributeReference("comment", StringType, nullable = true,
new MetadataBuilder().putString("comment", "comment of the column").build())()
)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index ed626fef56..c6e5868187 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -75,7 +75,7 @@ trait HashOuterJoin {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
protected[this] def resultProjection: InternalRow => InternalRow =
- UnsafeProjection.create(self.schema)
+ UnsafeProjection.create(output, output)
@transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
@transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
@@ -151,82 +151,4 @@ trait HashOuterJoin {
}
ret.iterator
}
-
- protected[this] def fullOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- rightIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- if (!key.anyNull) {
- // Store the positions of records in right, if one of its associated row satisfy
- // the join condition.
- val rightMatchedSet = scala.collection.mutable.Set[Int]()
- leftIter.iterator.flatMap[InternalRow] { l =>
- joinedRow.withLeft(l)
- var matched = false
- rightIter.zipWithIndex.collect {
- // 1. For those matched (satisfy the join condition) records with both sides filled,
- // append them directly
-
- case (r, idx) if boundCondition(joinedRow.withRight(r)) =>
- numOutputRows += 1
- matched = true
- // if the row satisfy the join condition, add its index into the matched set
- rightMatchedSet.add(idx)
- resultProjection(joinedRow)
-
- } ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
- // 2. For those unmatched records in left, append additional records with empty right.
-
- // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all
- // of the records in right side.
- // If we didn't get any proper row, then append a single row with empty right.
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow))
- })
- } ++ rightIter.zipWithIndex.collect {
- // 3. For those unmatched records in right, append additional records with empty left.
-
- // Re-visiting the records in right, and append additional row with empty left, if its not
- // in the matched set.
- case (r, idx) if !rightMatchedSet.contains(idx) =>
- numOutputRows += 1
- resultProjection(joinedRow(leftNullRow, r))
- }
- } else {
- leftIter.iterator.map[InternalRow] { l =>
- numOutputRows += 1
- resultProjection(joinedRow(l, rightNullRow))
- } ++ rightIter.iterator.map[InternalRow] { r =>
- numOutputRows += 1
- resultProjection(joinedRow(leftNullRow, r))
- }
- }
- }
-
- // This is only used by FullOuter
- protected[this] def buildHashTable(
- iter: Iterator[InternalRow],
- numIterRows: LongSQLMetric,
- keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = {
- val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]()
- while (iter.hasNext) {
- val currentRow = iter.next()
- numIterRows += 1
- val rowKey = keyGenerator(currentRow)
-
- var existingMatchList = hashTable.get(rowKey)
- if (existingMatchList == null) {
- existingMatchList = new CompactBuffer[InternalRow]()
- hashTable.put(rowKey.copy(), existingMatchList)
- }
-
- existingMatchList += currentRow.copy()
- }
-
- hashTable
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
index efaa69c1d3..7ce38ebdb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
@@ -114,7 +114,7 @@ case class SortMergeOuterJoin(
(r: InternalRow) => true
}
}
- val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema)
+ val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
joinType match {
case LeftOuter =>