From 4af647c77ded6a0d3087ceafb2e30e01d97e7a06 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 18 Dec 2015 10:09:17 -0800 Subject: [SPARK-12054] [SQL] Consider nullability of expression in codegen This could simplify the generated code for expressions that is not nullable. This PR fix lots of bugs about nullability. Author: Davies Liu Closes #10333 from davies/skip_nullable. --- .../sql/catalyst/expressions/BoundAttribute.scala | 17 ++-- .../spark/sql/catalyst/expressions/Cast.scala | 28 ++++--- .../sql/catalyst/expressions/Expression.scala | 95 ++++++++++++++-------- .../expressions/aggregate/CentralMomentAgg.scala | 2 +- .../sql/catalyst/expressions/aggregate/Corr.scala | 3 +- .../sql/catalyst/expressions/aggregate/Count.scala | 19 +++-- .../sql/catalyst/expressions/aggregate/Sum.scala | 24 ++++-- .../expressions/codegen/CodegenFallback.scala | 27 ++++-- .../codegen/GenerateMutableProjection.scala | 65 +++++++++------ .../codegen/GenerateUnsafeProjection.scala | 21 +++-- .../expressions/complexTypeExtractors.scala | 19 +++-- .../catalyst/expressions/datetimeExpressions.scala | 4 + .../catalyst/expressions/decimalExpressions.scala | 1 + .../sql/catalyst/expressions/jsonExpressions.scala | 15 ++-- .../sql/catalyst/expressions/mathExpressions.scala | 5 ++ .../spark/sql/catalyst/expressions/misc.scala | 1 + .../catalyst/expressions/stringExpressions.scala | 1 + .../catalyst/expressions/windowExpressions.scala | 4 +- .../catalyst/plans/logical/basicOperators.scala | 9 +- .../spark/sql/catalyst/expressions/CastSuite.scala | 10 +-- .../catalyst/expressions/ComplexTypeSuite.scala | 1 - 21 files changed, 242 insertions(+), 129 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbb..7293d5d447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + ev.isNull = "false" + s""" + $javaType ${ev.value} = $value; + """ + } } } @@ -92,7 +99,7 @@ object BindReferences extends Logging { sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } } else { - BoundReference(ordinal, a.dataType, a.nullable) + BoundReference(ordinal, a.dataType, input(ordinal).nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index cb60d5958d..b18f49f320 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -87,18 +87,22 @@ object Cast { private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case (NullType, _) => true + case (_, _) if from == to => false + + case (StringType, BinaryType) => false + case (StringType, _) => true + case (_, StringType) => false + + case (FloatType | DoubleType, TimestampType) => true + case (TimestampType, DateType) => false + case (_, DateType) => true + case (DateType, TimestampType) => false + case (DateType, _) => true + case (_, CalendarIntervalType) => true + + case (_, _: DecimalType) => true // overflow + case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6d807c9ecf..6a9c12127d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -340,14 +340,21 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(eval.value) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $resultCode - } - """ + if (nullable) { + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${eval.isNull}) { + ${f(eval.value)} + } + """ + } else { + ev.isNull = "false" + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${f(eval.value)} + """ + } } } @@ -424,19 +431,30 @@ abstract class BinaryExpression extends Expression { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.value, eval2.value) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; + if (nullable) { + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } } - } - """ + """ + + } else { + ev.isNull = "false" + s""" + ${eval1.code} + ${eval2.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } @@ -548,20 +566,31 @@ abstract class TernaryExpression extends Expression { f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode + if (nullable) { + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } } } - } - """ + """ + } else { + ev.isNull = "false" + s""" + ${evals(0).code} + ${evals(1).code} + ${evals(2).code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index d07d4c338c..30f602227b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 00d7436b71..d25f3335ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +41,7 @@ case class Corr( override def children: Seq[Expression] = Seq(left, right) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 441f52ab5c..663c69e799 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private lazy val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) - ) + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override lazy val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index cfb042e0aa..08a67ea3df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. - case NullType => IntegerType case _ => child.dataType } @@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate { /* sum = */ Literal.create(null, sumDataType) ) - override lazy val updateExpressions: Seq[Expression] = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } override lazy val mergeExpressions: Seq[Expression] = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ - Coalesce(Seq(add, sum.left)) + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) ) } - override lazy val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 26fb143d1e..80c5e41baa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -32,14 +32,23 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this.toCommentSafeString} */ - java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + if (nullable) { + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } else { + ev.isNull = "false" + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 40189f0877..a6ec242589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; - """ + if (e.nullable) { + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } else { + val value = s"value_$i" + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$value = ${evaluationCode.value}; + """ + } } val updates = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + if (e.nullable) { + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + s""" + if (this.isNull_$i) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } else { + s""" + if (this.isNull_$i) { + mutableRow.setNullAt($i); + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } } else { s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; """ } + } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 68005afb21..c1defe12b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.write($index, ${input.value});" } - s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { + if (input.isNull == "false") { + s""" + ${input.code} ${writeField.trim} - } - """ + """ + } else { + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} + } else { + ${writeField.trim} + } + """ + } } s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 58f6a7ec8a..c5ed173eeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -115,13 +115,19 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { + if (nullable) { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + } else { + s""" ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ + """ + } }) } } @@ -139,7 +145,6 @@ case class GetArrayStructFields( containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || containsNull || field.nullable override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03c39f8404..311540e335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -340,6 +340,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType + override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] @@ -455,6 +456,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) @@ -561,6 +563,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def nullSafeEval(start: Any, dayOfW: Any): Any = { val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) @@ -832,6 +835,7 @@ case class TruncDate(date: Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def prettyName: String = "trunc" private lazy val truncLevel: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 78f6631e46..c54bcdd774 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" protected override def nullSafeEval(input: Any): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 4991b9cb54..72b323587c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{StringWriter, ByteArrayOutputStream} +import java.io.{ByteArrayOutputStream, StringWriter} + +import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.parsing.combinator.RegexParsers - private[this] sealed trait PathInstruction private[this] object PathInstruction { private[expressions] case object Subscript extends PathInstruction @@ -108,15 +109,17 @@ private[this] object SharedFactory { case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - import SharedFactory._ + import com.fasterxml.jackson.core.JsonToken._ + import PathInstruction._ + import SharedFactory._ import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def prettyName: String = "get_json_object" @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 28f616fbb9..9c1a3294de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -75,6 +75,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) abstract class UnaryLogExpression(f: Double => Double, name: String) extends UnaryMathExpression(f, name) { + override def nullable: Boolean = true + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 @@ -194,6 +196,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { NumberConverter.convert( @@ -621,6 +624,8 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + override def nullable: Boolean = true + protected override def nullSafeEval(input1: Any, input2: Any): Any = { val dLeft = input1.asInstanceOf[Double] val dRight = input2.asInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e0..5baab4f7e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -57,6 +57,7 @@ case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8770c4b76c..50c8b9d598 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -924,6 +924,7 @@ case class FormatNumber(x: Expression, d: Expression) override def left: Expression = x override def right: Expression = d override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) // Associated with the pattern, for the last d value, and we will update the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 06252ac4fc..91f169e7ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -329,7 +329,7 @@ abstract class OffsetWindowFunction */ override def foldable: Boolean = input.foldable && (default == null || default.foldable) - override def nullable: Boolean = input.nullable && (default == null || default.nullable) + override def nullable: Boolean = default == null || default.nullable override lazy val frame = { // This will be triggered by the Analyzer. @@ -381,7 +381,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF self: Product => override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType - override def nullable: Boolean = false + override def nullable: Boolean = true override def supportsPartial: Boolean = false override lazy val mergeExpressions = throw new UnsupportedOperationException("Window Functions do not support merging.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5665fd7e5f..ec42b763f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -293,7 +293,14 @@ private[sql] object Expand { Literal.create(bitmask, IntegerType) }) } - Expand(projections, child.output :+ gid, child) + val output = child.output.map { attr => + if (groupByExprs.exists(_.semanticEquals(attr))) { + attr.withNullability(true) + } else { + attr + } + } + Expand(projections, output :+ gid, child) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a98e16c253..c99a4ac964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from string") { assert(cast("abcdef", StringType).nullable === false) assert(cast("abcdef", BinaryType).nullable === false) - assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === true) assert(cast("abcdef", TimestampType).nullable === true) assert(cast("abcdef", LongType).nullable === true) assert(cast("abcdef", IntegerType).nullable === true) @@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Seq(null, true, false)) } @@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { @@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, InternalRow(null, true, false)) } @@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> null, "b" -> true, "c" -> false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 9f1b19253e..9c1688b261 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -- cgit v1.2.3