diff options
author | Wenchen Fan <wenchen@databricks.com> | 2017-01-03 22:40:14 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2017-01-03 22:40:14 -0800 |
commit | cbd11d235752d0ab30cfdbf2351cb3e68a123606 (patch) | |
tree | 38ec8095c4f42a6e10f19c3583ec8a573a60f91e | |
parent | b67b35f76b684c5176dc683e7491fd01b43f4467 (diff) | |
download | spark-cbd11d235752d0ab30cfdbf2351cb3e68a123606.tar.gz spark-cbd11d235752d0ab30cfdbf2351cb3e68a123606.tar.bz2 spark-cbd11d235752d0ab30cfdbf2351cb3e68a123606.zip |
[SPARK-19072][SQL] codegen of Literal should not output boxed value
## What changes were proposed in this pull request?
In https://github.com/apache/spark/pull/16402 we made a mistake that, when double/float is infinity, the `Literal` codegen will output boxed value and cause wrong result.
This PR fixes this by special handling infinity to not output boxed value.
## How was this patch tested?
new regression test
Author: Wenchen Fan <wenchen@databricks.com>
Closes #16469 from cloud-fan/literal.
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 30 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala | 5 |
2 files changed, 24 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ab45c41bc0..cb0c4d333b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -266,33 +266,41 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = ctx.javaType(dataType) // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};") + ev.copy(s"final $javaType ${ev.value} = ${ctx.defaultValue(dataType)};") } else { ev.isNull = "false" - ev.value = dataType match { - case BooleanType | IntegerType | DateType => value.toString + dataType match { + case BooleanType | IntegerType | DateType => + ev.copy(code = "", value = value.toString) case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - ctx.addReferenceMinorObj(v) + val boxedValue = ctx.addReferenceMinorObj(v) + val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" + ev.copy(code = code) } else { - s"${value}f" + ev.copy(code = "", value = s"${value}f") } case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - ctx.addReferenceMinorObj(v) + val boxedValue = ctx.addReferenceMinorObj(v) + val code = s"final $javaType ${ev.value} = ($javaType) $boxedValue;" + ev.copy(code = code) } else { - s"${value}D" + ev.copy(code = "", value = s"${value}D") } - case ByteType | ShortType => s"(${ctx.javaType(dataType)})$value" - case TimestampType | LongType => s"${value}L" - case other => ctx.addReferenceMinorObj(value, ctx.javaType(dataType)) + case ByteType | ShortType => + ev.copy(code = "", value = s"($javaType)$value") + case TimestampType | LongType => + ev.copy(code = "", value = s"${value}L") + case other => + ev.copy(code = "", value = ctx.addReferenceMinorObj(value, ctx.javaType(dataType))) } - ev.copy("") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6fc3de178f..6fe295c3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -324,4 +324,9 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { Literal.create(struct, structType), Literal.create(unsafeStruct, structType)), true) } + + test("EqualTo double/float infinity") { + val infinity = Literal(Double.PositiveInfinity) + checkEvaluation(EqualTo(infinity, infinity), true) + } } |