author    Josh Rosen <joshrosen@databricks.com>    2015-07-20 22:38:05 -0700
committer Josh Rosen <joshrosen@databricks.com>    2015-07-20 22:38:05 -0700
commit    c032b0bf92130dc4facb003f0deaeb1228aefded
tree      2a48b364f0afb53e1ff10dac554b54d54abfa11b
parent    4d97be95300f729391c17b4c162e3c7fba09b8bf
[SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL
This patch addresses an issue where queries that sorted float or double columns containing NaN values could fail with "Comparison method violates its general contract!" errors from TimSort. The root of this problem is that `NaN > anything`, `NaN == anything`, and `NaN < anything` all return `false`. Per the design specified in SPARK-9079, we have decided that `NaN = NaN` should return true and that NaN should appear last when sorting in ascending order (i.e. it is larger than any other numeric value).

In addition to implementing these semantics, this patch also adds canonicalization of NaN values in UnsafeRow, which is necessary in order to be able to do binary equality comparisons on equal NaNs that might have different bit representations (see SPARK-9147).

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7194 from JoshRosen/nan and squashes the following commits:

983d4fc [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
88bd73c [Josh Rosen] Fix Row.equals()
a702e2e [Josh Rosen] normalization -> canonicalization
a7267cf [Josh Rosen] Normalize NaNs in UnsafeRow
fe629ae [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
fbb2a29 [Josh Rosen] Fix NaN comparisons in BinaryComparison expressions
c1fd4fe [Josh Rosen] Fold NaN test into existing test framework
b31eb19 [Josh Rosen] Uncomment failing tests
7fe67af [Josh Rosen] Support NaN == NaN (SPARK-9145)
58bad2c [Josh Rosen] Revert "Compare rows' string representations to work around NaN incomparability."
fc6b4d2 [Josh Rosen] Update CodeGenerator
3998ef2 [Josh Rosen] Remove unused code
a2ba2e7 [Josh Rosen] Fix prefix comparision for NaNs
a30d371 [Josh Rosen] Compare rows' string representations to work around NaN incomparability.
6f03f85 [Josh Rosen] Fix bug in Double / Float ordering
42a1ad5 [Josh Rosen] Stop filtering NaNs in UnsafeExternalSortSuite
bfca524 [Josh Rosen] Change ordering so that NaN is maximum value.
8d7be61 [Josh Rosen] Update randomized test to use ScalaTest's assume()
b20837b [Josh Rosen] Add failing test for new NaN comparision ordering
5b88b2b [Josh Rosen] Fix compilation of CodeGenerationSuite
d907b5b [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
630ebc5 [Josh Rosen] Specify an ordering for NaN values.
9bf195a [Josh Rosen] Re-enable NaNs in CodeGenerationSuite to produce more regression tests
13fc06a [Josh Rosen] Add regression test for NaN sorting issue
f9efbb5 [Josh Rosen] Fix ORDER BY NULL
e7dc4fb [Josh Rosen] Add very generic test for ordering
7d5c13e [Josh Rosen] Add regression test for SPARK-8782 (ORDER BY NULL)
b55875a [Josh Rosen] Generate doubles and floats over entire possible range.
5acdd5c [Josh Rosen] Infinity and NaN are interesting.
ab76cbd [Josh Rosen] Move code to Catalyst package.
d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL.
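The failure mode is easy to reproduce outside Spark. A minimal standalone Scala sketch (illustrative only, not part of this patch) of how the pre-patch comparison logic violates the consistency TimSort requires:

object NaNContractSketch {
  // Pre-patch logic: every relational operator involving NaN returns false,
  // so this comparator reports 0 ("equal") for any pair that includes NaN.
  def naiveCompare(a: Double, b: Double): Int =
    if (a < b) -1 else if (a > b) 1 else 0

  def main(args: Array[String]): Unit = {
    println(naiveCompare(Double.NaN, 1.0)) // 0: NaN "equals" 1.0
    println(naiveCompare(Double.NaN, 2.0)) // 0: NaN "equals" 2.0
    println(naiveCompare(1.0, 2.0))        // -1: yet 1.0 < 2.0, so the reported
                                           // "equality" is not transitive and
                                           // TimSort may throw at runtime
  }
}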
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java              |  5
 core/src/main/scala/org/apache/spark/util/Utils.scala                                               | 28
 core/src/test/scala/org/apache/spark/util/UtilsSuite.scala                                          | 31
 core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala       | 25
 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                 |  6
 sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                                          | 24
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala   |  4
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala              | 22
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala                             |  5
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala                              |  5
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala     | 39
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala          | 13
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 22
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                                   | 22
 sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala                                         | 12
 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala                |  6
 16 files changed, 243 insertions(+), 26 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 438742565c..bf1bc5dffb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@ import com.google.common.primitives.UnsignedBytes;
import org.apache.spark.annotation.Private;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
@Private
public class PrefixComparators {
@@ -82,7 +83,7 @@ public class PrefixComparators {
public int compare(long aPrefix, long bPrefix) {
float a = Float.intBitsToFloat((int) aPrefix);
float b = Float.intBitsToFloat((int) bPrefix);
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ return Utils.nanSafeCompareFloats(a, b);
}
public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public class PrefixComparators {
public int compare(long aPrefix, long bPrefix) {
double a = Double.longBitsToDouble(aPrefix);
double b = Double.longBitsToDouble(bPrefix);
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ return Utils.nanSafeCompareDoubles(a, b);
}
public long computePrefix(double value) {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e6374f17d8..c5816949cd 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
hashAbs
}
+ /**
+ * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN > any non-NaN double.
+ */
+ def nanSafeCompareDoubles(x: Double, y: Double): Int = {
+ val xIsNan: Boolean = java.lang.Double.isNaN(x)
+ val yIsNan: Boolean = java.lang.Double.isNaN(y)
+ if ((xIsNan && yIsNan) || (x == y)) 0
+ else if (xIsNan) 1
+ else if (yIsNan) -1
+ else if (x > y) 1
+ else -1
+ }
+
+ /**
+ * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
+ * according to semantics where NaN == NaN and NaN > any non-NaN float.
+ */
+ def nanSafeCompareFloats(x: Float, y: Float): Int = {
+ val xIsNan: Boolean = java.lang.Float.isNaN(x)
+ val yIsNan: Boolean = java.lang.Float.isNaN(y)
+ if ((xIsNan && yIsNan) || (x == y)) 0
+ else if (xIsNan) 1
+ else if (yIsNan) -1
+ else if (x > y) 1
+ else -1
+ }
+
/** Returns the system properties map that is thread-safe to iterator over. It gets the
* properties which have been set explicitly, as well as those for which only a default value
* has been defined. */
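Illustrative usage of the new helpers (a sketch, not part of the patch, and assuming the calling code lives inside the org.apache.spark package, since Utils is private[spark]): they define a total order in which NaN equals itself and is greater than every other value, including positive infinity.

import org.apache.spark.util.Utils

val xs = Seq(Double.NaN, 1.0, Double.PositiveInfinity, -3.0)
val sorted = xs.sortWith(Utils.nanSafeCompareDoubles(_, _) < 0)
// sorted == Seq(-3.0, 1.0, Double.PositiveInfinity, Double.NaN):
// NaN orders last in ascending order
assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) == 0) // NaN == NaN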
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c7638507c8..8f7e402d5f 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
// scalastyle:on println
assert(buffer.toString === "t circular test circular\n")
}
+
+ test("nanSafeCompareDoubles") {
+ def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+ assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+ assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+ }
+ shouldMatchDefaultOrder(0d, 0d)
+ shouldMatchDefaultOrder(0d, 1d)
+ shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+ assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+ assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+ assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+ }
+
+ test("nanSafeCompareFloats") {
+ def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+ assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+ assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+ }
+ shouldMatchDefaultOrder(0f, 0f)
+ shouldMatchDefaultOrder(1f, 1f)
+ shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+ assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+ assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+ assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505dfa7d..dc03e374b5 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
}
+
+ test("float prefix comparator handles NaN properly") {
+ val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+ val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+ assert(nan1.isNaN)
+ assert(nan2.isNaN)
+ val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+ assert(nan1Prefix === nan2Prefix)
+ val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+ assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+ }
+
+ test("double prefix comparator handles NaNs properly") {
+ val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
+ val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+ assert(nan1.isNaN)
+ assert(nan2.isNaN)
+ val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
+ val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+ assert(nan1Prefix === nan2Prefix)
+ val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+ assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
+ }
+
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87294a0e21..8cd9e7bc60 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -215,6 +215,9 @@ public final class UnsafeRow extends MutableRow {
public void setDouble(int ordinal, double value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ }
PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
}
@@ -243,6 +246,9 @@ public final class UnsafeRow extends MutableRow {
public void setFloat(int ordinal, float value) {
assertIndexIsValid(ordinal);
setNotNullAt(ordinal);
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ }
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
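Why this rewrite is needed (a sketch, not part of the patch): IEEE 754 permits many distinct NaN bit patterns, and UnsafeRow equality is a byte-wise comparison, so two rows holding "equal" NaNs could otherwise compare unequal (SPARK-9147).

val nan1 = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2 = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
assert(nan1.isNaN && nan2.isNaN)
// Raw bit patterns differ even though both values are NaN:
assert(java.lang.Double.doubleToRawLongBits(nan1) !=
  java.lang.Double.doubleToRawLongBits(nan2))
// The setDouble/setFloat rewrite above stores the canonical NaN instead,
// making the stored bytes identical:
val canon1 = if (nan1.isNaN) Double.NaN else nan1
val canon2 = if (nan2.isNaN) Double.NaN else nan2
assert(java.lang.Double.doubleToRawLongBits(canon1) ==
  java.lang.Double.doubleToRawLongBits(canon2))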
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 2cb64d0093..91449479fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -403,20 +403,28 @@ trait Row extends Serializable {
if (!isNullAt(i)) {
val o1 = get(i)
val o2 = other.get(i)
- if (o1.isInstanceOf[Array[Byte]]) {
- // handle equality of Array[Byte]
- val b1 = o1.asInstanceOf[Array[Byte]]
- if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ o1 match {
+ case b1: Array[Byte] =>
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ case f1: Float if java.lang.Float.isNaN(f1) =>
+ if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ return false
+ }
+ case d1: Double if java.lang.Double.isNaN(d1) =>
+ if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ return false
+ }
+ case _ => if (o1 != o2) {
return false
}
- } else if (o1 != o2) {
- return false
}
}
i += 1
}
- return true
+ true
}
/* ---------------------- utility methods for Scala ---------------------- */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 10f411ff74..606f770cb4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -194,6 +194,8 @@ class CodeGenContext {
*/
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+ case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
+ case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case other => s"$c1.equals($c2)"
}
@@ -204,6 +206,8 @@ class CodeGenContext {
def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
// java boolean doesn't support > or < operator
case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
+ case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
+ case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
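For reference (a sketch, not part of the patch), the DoubleType equality template above splices into generated Java that is equivalent to the following Scala function: ordinary IEEE equality plus a special case making NaN equal to itself.

def generatedDoubleEqual(c1: Double, c2: Double): Boolean =
  (java.lang.Double.isNaN(c1) && java.lang.Double.isNaN(c2)) || c1 == c2

assert(generatedDoubleEqual(Double.NaN, Double.NaN)) // true under new semantics
assert(!generatedDoubleEqual(Double.NaN, 1.0))       // NaN != any non-NaN value
assert(generatedDoubleEqual(0.0, -0.0))              // plain IEEE equality elsewhere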
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 40ec3df224..a53ec31ee6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
object InterpretedPredicate {
@@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
abstract class BinaryComparison extends BinaryOperator with Predicate {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- if (ctx.isPrimitiveType(left.dataType)) {
+ if (ctx.isPrimitiveType(left.dataType)
+ && left.dataType != FloatType
+ && left.dataType != DoubleType) {
// faster version
defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
} else {
@@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
override def symbol: String = "="
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
- if (left.dataType != BinaryType) input1 == input2
- else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+ if (left.dataType == FloatType) {
+ Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+ } else if (left.dataType == DoubleType) {
+ Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+ } else if (left.dataType != BinaryType) {
+ input1 == input2
+ } else {
+ java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+ }
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (input1 == null || input2 == null) {
false
} else {
- if (left.dataType != BinaryType) {
+ if (left.dataType == FloatType) {
+ Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+ } else if (left.dataType == DoubleType) {
+ Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+ } else if (left.dataType != BinaryType) {
input1 == input2
} else {
java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
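The end-to-end effect of these predicate changes, as a hedged sketch (assumes a Spark 1.x SQLContext bound to the name sqlContext with its implicits imported; the names here are illustrative, not from the patch):

import sqlContext.implicits._

val df = Seq((Double.NaN, Double.NaN), (1.0, 2.0)).toDF("x", "y")
// EqualTo now routes Float/Double comparisons through the NaN-safe helpers,
// so the (NaN, NaN) row satisfies the equality predicate:
df.filter($"x" === $"y").count() // 1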
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index 986c2ab055..2a1bf0938e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
private[sql] val numeric = implicitly[Numeric[Double]]
private[sql] val fractional = implicitly[Fractional[Double]]
- private[sql] val ordering = implicitly[Ordering[InternalType]]
+ private[sql] val ordering = new Ordering[Double] {
+ override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
+ }
private[sql] val asIntegral = DoubleAsIfIntegral
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 9bd48ece83..08e22252ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
private[sql] val numeric = implicitly[Numeric[Float]]
private[sql] val fractional = implicitly[Fractional[Float]]
- private[sql] val ordering = implicitly[Ordering[InternalType]]
+ private[sql] val ordering = new Ordering[Float] {
+ override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
+ }
private[sql] val asIntegral = FloatAsIfIntegral
/**
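A consequence of the new type orderings (sketch, not part of the patch; subject to the same private[spark] visibility caveat noted earlier): an Ordering built this way sorts NaN after every other value, including positive infinity.

val nanSafeFloatOrdering = new Ordering[Float] {
  override def compare(x: Float, y: Float): Int =
    org.apache.spark.util.Utils.nanSafeCompareFloats(x, y)
}
Seq(Float.NaN, 1.0f, Float.PositiveInfinity).sorted(nanSafeFloatOrdering)
// => List(1.0, Infinity, NaN)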
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e05218a23a..f4fbc49677 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,9 +17,14 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.math._
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
/**
* Additional tests for code generation.
@@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
futures.foreach(Await.result(_, 10.seconds))
}
+ // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+ // contain two columns of that type, then for pairs of randomly-generated rows we check that
+ // GenerateOrdering agrees with RowOrdering.
+ (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
+ test(s"GenerateOrdering with $dataType") {
+ val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+ val genOrdering = GenerateOrdering.generate(
+ BoundReference(0, dataType, nullable = true).asc ::
+ BoundReference(1, dataType, nullable = true).asc :: Nil)
+ val rowType = StructType(
+ StructField("a", dataType, nullable = true) ::
+ StructField("b", dataType, nullable = true) :: Nil)
+ val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+ assume(maybeDataGenerator.isDefined)
+ val randGenerator = maybeDataGenerator.get
+ val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+ for (_ <- 1 to 50) {
+ val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+ val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+ withClue(s"a = $a, b = $b") {
+ assert(genOrdering.compare(a, a) === 0)
+ assert(genOrdering.compare(b, b) === 0)
+ assert(rowOrdering.compare(a, a) === 0)
+ assert(rowOrdering.compare(b, b) === 0)
+ assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+ assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+ assert(
+ signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+ "Generated and non-generated orderings should agree")
+ }
+ }
+ }
+ }
+
test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
val length = 5000
val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2173a0c25c..0bc2812a5d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}
- private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
- private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
-
- private val equalValues1 = smallValues
- private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
+ private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+ private val largeValues =
+ Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+
+ private val equalValues1 =
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+ private val equalValues2 =
+ Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
test("BinaryComparison: <") {
for (i <- 0 until smallValues.length) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index d00aeb4dfb..dff5faf9f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
+ test("NaN canonicalization") {
+ val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
+
+ val row1 = new SpecificMutableRow(fieldTypes)
+ row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
+ row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
+
+ val row2 = new SpecificMutableRow(fieldTypes)
+ row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
+ row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
+
+ val converter = new UnsafeRowConverter(fieldTypes)
+ val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
+ val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
+ converter.writeRow(
+ row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
+ converter.writeRow(
+ row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)
+
+ assert(row1Buffer.toSeq === row2Buffer.toSeq)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 192cc0a6e5..f67f2c60c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.File
import scala.language.postfixOps
+import scala.util.Random
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.functions._
@@ -742,6 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
df.col("t.``")
}
+ test("SPARK-8797: sort by float column containing NaN should not crash") {
+ val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
+ val df = Random.shuffle(inputData).toDF("a")
+ df.orderBy("a").collect()
+ }
+
+ test("SPARK-8797: sort by double column containing NaN should not crash") {
+ val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
+ val df = Random.shuffle(inputData).toDF("a")
+ df.orderBy("a").collect()
+ }
+
+ test("NaN is greater than all other non-NaN numeric values") {
+ val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
+ .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+ assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
+ val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
+ .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+ assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
+ }
+
test("SPARK-8072: Better Exception for Duplicate Columns") {
// only one duplicate column present
val e = intercept[org.apache.spark.sql.AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index d84b57af9c..7cc6ffd754 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite {
row.getAs[Int]("c")
}
}
+
+ test("float NaN == NaN") {
+ val r1 = Row(Float.NaN)
+ val r2 = Row(Float.NaN)
+ assert(r1 === r2)
+ }
+
+ test("double NaN == NaN") {
+ val r1 = Row(Double.NaN)
+ val r2 = Row(Double.NaN)
+ assert(r1 === r2)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index 4f4c1f2856..5fe73f7e0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
) {
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
- val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
- case d: Double => !d.isNaN
- case f: Float => !java.lang.Float.isNaN(f)
- case x => true
- }
+ val inputData = Seq.fill(1000)(randomDataGenerator())
val inputDf = TestSQLContext.createDataFrame(
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)