aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-12-18 10:09:17 -0800
committerDavies Liu <davies.liu@gmail.com>2015-12-18 10:09:17 -0800
commit4af647c77ded6a0d3087ceafb2e30e01d97e7a06 (patch)
tree893ad3f9d8de5b34a9ad46eba4862f898c68f044 /sql/catalyst
parentee444fe4b8c9f382524e1fa346c67ba6da8104d8 (diff)
downloadspark-4af647c77ded6a0d3087ceafb2e30e01d97e7a06.tar.gz
spark-4af647c77ded6a0d3087ceafb2e30e01d97e7a06.tar.bz2
spark-4af647c77ded6a0d3087ceafb2e30e01d97e7a06.zip
[SPARK-12054] [SQL] Consider nullability of expression in codegen
This could simplify the generated code for expressions that are not nullable. This PR fixes many bugs related to nullability. Author: Davies Liu <davies@databricks.com> Closes #10333 from davies/skip_nullable.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala28
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala95
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala65
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala9
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala10
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala1
21 files changed, 242 insertions, 129 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ff1f28ddbb..7293d5d447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
- s"""
- boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
- $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
- """
+ if (nullable) {
+ s"""
+ boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
+ $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ $javaType ${ev.value} = $value;
+ """
+ }
}
}
@@ -92,7 +99,7 @@ object BindReferences extends Logging {
sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
}
} else {
- BoundReference(ordinal, a.dataType, a.nullable)
+ BoundReference(ordinal, a.dataType, input(ordinal).nullable)
}
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index cb60d5958d..b18f49f320 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -87,18 +87,22 @@ object Cast {
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
private def forceNullable(from: DataType, to: DataType) = (from, to) match {
- case (StringType, _: NumericType) => true
- case (StringType, TimestampType) => true
- case (DoubleType, TimestampType) => true
- case (FloatType, TimestampType) => true
- case (StringType, DateType) => true
- case (_: NumericType, DateType) => true
- case (BooleanType, DateType) => true
- case (DateType, _: NumericType) => true
- case (DateType, BooleanType) => true
- case (DoubleType, _: DecimalType) => true
- case (FloatType, _: DecimalType) => true
- case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
+ case (NullType, _) => true
+ case (_, _) if from == to => false
+
+ case (StringType, BinaryType) => false
+ case (StringType, _) => true
+ case (_, StringType) => false
+
+ case (FloatType | DoubleType, TimestampType) => true
+ case (TimestampType, DateType) => false
+ case (_, DateType) => true
+ case (DateType, TimestampType) => false
+ case (DateType, _) => true
+ case (_, CalendarIntervalType) => true
+
+ case (_, _: DecimalType) => true // overflow
+ case (_: FractionalType, _: IntegralType) => true // NaN, infinity
case _ => false
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 6d807c9ecf..6a9c12127d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -340,14 +340,21 @@ abstract class UnaryExpression extends Expression {
ev: GeneratedExpressionCode,
f: String => String): String = {
val eval = child.gen(ctx)
- val resultCode = f(eval.value)
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- $resultCode
- }
- """
+ if (nullable) {
+ eval.code + s"""
+ boolean ${ev.isNull} = ${eval.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${eval.isNull}) {
+ ${f(eval.value)}
+ }
+ """
+ } else {
+ ev.isNull = "false"
+ eval.code + s"""
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ ${f(eval.value)}
+ """
+ }
}
}
@@ -424,19 +431,30 @@ abstract class BinaryExpression extends Expression {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.value, eval2.value)
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = ${eval1.isNull};
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${eval2.code}
- if (!${eval2.isNull}) {
- $resultCode
- } else {
- ${ev.isNull} = true;
+ if (nullable) {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${eval2.code}
+ if (!${eval2.isNull}) {
+ $resultCode
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+
+ } else {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ ${eval2.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $resultCode
+ """
+ }
}
}
@@ -548,20 +566,31 @@ abstract class TernaryExpression extends Expression {
f: (String, String, String) => String): String = {
val evals = children.map(_.gen(ctx))
val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
- s"""
- ${evals(0).code}
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${evals(0).isNull}) {
- ${evals(1).code}
- if (!${evals(1).isNull}) {
- ${evals(2).code}
- if (!${evals(2).isNull}) {
- ${ev.isNull} = false; // resultCode could change nullability
- $resultCode
+ if (nullable) {
+ s"""
+ ${evals(0).code}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${evals(0).isNull}) {
+ ${evals(1).code}
+ if (!${evals(1).isNull}) {
+ ${evals(2).code}
+ if (!${evals(2).isNull}) {
+ ${ev.isNull} = false; // resultCode could change nullability
+ $resultCode
+ }
}
}
- }
- """
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ ${evals(0).code}
+ ${evals(1).code}
+ ${evals(2).code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ $resultCode
+ """
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index d07d4c338c..30f602227b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
override def children: Seq[Expression] = Seq(child)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def dataType: DataType = DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 00d7436b71..d25f3335ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
@@ -42,7 +41,7 @@ case class Corr(
override def children: Seq[Expression] = Seq(left, right)
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def dataType: DataType = DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 441f52ab5c..663c69e799 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
- private lazy val count = AttributeReference("count", LongType)()
+ private lazy val count = AttributeReference("count", LongType, nullable = false)()
override lazy val aggBufferAttributes = count :: Nil
@@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
/* count = */ Literal(0L)
)
- override lazy val updateExpressions = Seq(
- /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
- )
+ override lazy val updateExpressions = {
+ val nullableChildren = children.filter(_.nullable)
+ if (nullableChildren.isEmpty) {
+ Seq(
+ /* count = */ count + 1L
+ )
+ } else {
+ Seq(
+ /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
+ )
+ }
+ }
override lazy val mergeExpressions = Seq(
/* count = */ count.left + count.right
)
- override lazy val evaluateExpression = Cast(count, LongType)
+ override lazy val evaluateExpression = count
override def defaultResult: Option[Literal] = Option(Literal(0L))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index cfb042e0aa..08a67ea3df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
- // TODO: Remove this line once we remove the NullType from inputTypes.
- case NullType => IntegerType
case _ => child.dataType
}
@@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
/* sum = */ Literal.create(null, sumDataType)
)
- override lazy val updateExpressions: Seq[Expression] = Seq(
- /* sum = */
- Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
- )
+ override lazy val updateExpressions: Seq[Expression] = {
+ if (child.nullable) {
+ Seq(
+ /* sum = */
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
+ )
+ } else {
+ Seq(
+ /* sum = */
+ Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType))
+ )
+ }
+ }
override lazy val mergeExpressions: Seq[Expression] = {
- val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
/* sum = */
- Coalesce(Seq(add, sum.left))
+ Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left))
)
}
- override lazy val evaluateExpression: Expression = Cast(sum, resultType)
+ override lazy val evaluateExpression: Expression = sum
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 26fb143d1e..80c5e41baa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -32,14 +32,23 @@ trait CodegenFallback extends Expression {
ctx.references += this
val objectTerm = ctx.freshName("obj")
- s"""
- /* expression: ${this.toCommentSafeString} */
- java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
- boolean ${ev.isNull} = $objectTerm == null;
- ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
- if (!${ev.isNull}) {
- ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
- }
- """
+ if (nullable) {
+ s"""
+ /* expression: ${this.toCommentSafeString} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+ boolean ${ev.isNull} = $objectTerm == null;
+ ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ }
+ """
+ } else {
+ ev.isNull = "false"
+ s"""
+ /* expression: ${this.toCommentSafeString} */
+ Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW});
+ ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
+ """
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 40189f0877..a6ec242589 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
- val isNull = s"isNull_$i"
- val value = s"value_$i"
- ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
- ctx.addMutableState(ctx.javaType(e.dataType), value,
- s"this.$value = ${ctx.defaultValue(e.dataType)};")
- s"""
- ${evaluationCode.code}
- this.$isNull = ${evaluationCode.isNull};
- this.$value = ${evaluationCode.value};
- """
+ if (e.nullable) {
+ val isNull = s"isNull_$i"
+ val value = s"value_$i"
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$isNull = ${evaluationCode.isNull};
+ this.$value = ${evaluationCode.value};
+ """
+ } else {
+ val value = s"value_$i"
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$value = ${evaluationCode.value};
+ """
+ }
}
val updates = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
- if (e.dataType.isInstanceOf[DecimalType]) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- s"""
- if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, null)};
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
+ if (e.nullable) {
+ if (e.dataType.isInstanceOf[DecimalType]) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ s"""
+ if (this.isNull_$i) {
+ ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
+ }
+ """
+ } else {
+ s"""
+ if (this.isNull_$i) {
+ mutableRow.setNullAt($i);
+ } else {
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
+ }
+ """
+ }
} else {
s"""
- if (this.isNull_$i) {
- mutableRow.setNullAt($i);
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
"""
}
+
}
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 68005afb21..c1defe12b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$rowWriter.write($index, ${input.value});"
}
- s"""
- ${input.code}
- if (${input.isNull}) {
- ${setNull.trim}
- } else {
+ if (input.isNull == "false") {
+ s"""
+ ${input.code}
${writeField.trim}
- }
- """
+ """
+ } else {
+ s"""
+ ${input.code}
+ if (${input.isNull}) {
+ ${setNull.trim}
+ } else {
+ ${writeField.trim}
+ }
+ """
+ }
}
s"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 58f6a7ec8a..c5ed173eeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -115,13 +115,19 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
- s"""
- if ($eval.isNullAt($ordinal)) {
- ${ev.isNull} = true;
- } else {
+ if (nullable) {
+ s"""
+ if ($eval.isNullAt($ordinal)) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
+ }
+ """
+ } else {
+ s"""
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
- }
- """
+ """
+ }
})
}
}
@@ -139,7 +145,6 @@ case class GetArrayStructFields(
containsNull: Boolean) extends UnaryExpression {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable || containsNull || field.nullable
override def toString: String = s"$child.${field.name}"
protected override def nullSafeEval(input: Any): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 03c39f8404..311540e335 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -340,6 +340,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
Seq(TypeCollection(StringType, DateType, TimestampType), StringType)
override def dataType: DataType = LongType
+ override def nullable: Boolean = true
private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
@@ -455,6 +456,7 @@ case class FromUnixTime(sec: Expression, format: Expression)
}
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
@@ -561,6 +563,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
+ override def nullable: Boolean = true
override def nullSafeEval(start: Any, dayOfW: Any): Any = {
val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String])
@@ -832,6 +835,7 @@ case class TruncDate(date: Expression, format: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
override def dataType: DataType = DateType
+ override def nullable: Boolean = true
override def prettyName: String = "trunc"
private lazy val truncLevel: Int =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 78f6631e46..c54bcdd774 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
override def dataType: DataType = DecimalType(precision, scale)
+ override def nullable: Boolean = true
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
protected override def nullSafeEval(input: Any): Any =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 4991b9cb54..72b323587c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,18 +17,19 @@
package org.apache.spark.sql.catalyst.expressions
-import java.io.{StringWriter, ByteArrayOutputStream}
+import java.io.{ByteArrayOutputStream, StringWriter}
+
+import scala.util.parsing.combinator.RegexParsers
import com.fasterxml.jackson.core._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType}
+import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
-import scala.util.parsing.combinator.RegexParsers
-
private[this] sealed trait PathInstruction
private[this] object PathInstruction {
private[expressions] case object Subscript extends PathInstruction
@@ -108,15 +109,17 @@ private[this] object SharedFactory {
case class GetJsonObject(json: Expression, path: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
- import SharedFactory._
+ import com.fasterxml.jackson.core.JsonToken._
+
import PathInstruction._
+ import SharedFactory._
import WriteStyle._
- import com.fasterxml.jackson.core.JsonToken._
override def left: Expression = json
override def right: Expression = path
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def prettyName: String = "get_json_object"
@transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 28f616fbb9..9c1a3294de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -75,6 +75,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String)
abstract class UnaryLogExpression(f: Double => Double, name: String)
extends UnaryMathExpression(f, name) {
+ override def nullable: Boolean = true
+
// values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
protected val yAsymptote: Double = 0.0
@@ -194,6 +196,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
NumberConverter.convert(
@@ -621,6 +624,8 @@ case class Logarithm(left: Expression, right: Expression)
this(EulerNumber(), child)
}
+ override def nullable: Boolean = true
+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val dLeft = input1.asInstanceOf[Double]
val dRight = input2.asInstanceOf[Double]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 0f6d02f2e0..5baab4f7e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -57,6 +57,7 @@ case class Sha2(left: Expression, right: Expression)
extends BinaryExpression with Serializable with ImplicitCastInputTypes {
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 8770c4b76c..50c8b9d598 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -924,6 +924,7 @@ case class FormatNumber(x: Expression, d: Expression)
override def left: Expression = x
override def right: Expression = d
override def dataType: DataType = StringType
+ override def nullable: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
// Associated with the pattern, for the last d value, and we will update the
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 06252ac4fc..91f169e7ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -329,7 +329,7 @@ abstract class OffsetWindowFunction
*/
override def foldable: Boolean = input.foldable && (default == null || default.foldable)
- override def nullable: Boolean = input.nullable && (default == null || default.nullable)
+ override def nullable: Boolean = default == null || default.nullable
override lazy val frame = {
// This will be triggered by the Analyzer.
@@ -381,7 +381,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
self: Product =>
override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
override def dataType: DataType = IntegerType
- override def nullable: Boolean = false
+ override def nullable: Boolean = true
override def supportsPartial: Boolean = false
override lazy val mergeExpressions =
throw new UnsupportedOperationException("Window Functions do not support merging.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5665fd7e5f..ec42b763f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -293,7 +293,14 @@ private[sql] object Expand {
Literal.create(bitmask, IntegerType)
})
}
- Expand(projections, child.output :+ gid, child)
+ val output = child.output.map { attr =>
+ if (groupByExprs.exists(_.semanticEquals(attr))) {
+ attr.withNullability(true)
+ } else {
+ attr
+ }
+ }
+ Expand(projections, output :+ gid, child)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index a98e16c253..c99a4ac964 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from string") {
assert(cast("abcdef", StringType).nullable === false)
assert(cast("abcdef", BinaryType).nullable === false)
- assert(cast("abcdef", BooleanType).nullable === false)
+ assert(cast("abcdef", BooleanType).nullable === true)
assert(cast("abcdef", TimestampType).nullable === true)
assert(cast("abcdef", LongType).nullable === true)
assert(cast("abcdef", IntegerType).nullable === true)
@@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Seq(null, true, false))
}
@@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
@@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("a", BooleanType, nullable = true),
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, InternalRow(null, true, false))
}
@@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType(Seq(
StructField("l", LongType, nullable = true)))))))
- assert(ret.resolved === true)
+ assert(ret.resolved === false)
checkEvaluation(ret, Row(
Seq(123, null, null),
Map("a" -> null, "b" -> true, "c" -> false),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9f1b19253e..9c1688b261 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._