about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 13
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala | 5
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala | 20
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala | 3
5 files changed, 30 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index afecf881c7..5eb5b0d176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -67,6 +67,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
(DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
s"$DoublePrefixCmp.computePrefix((double)$input)")
case StringType => (0L, s"$input.getPrefix()")
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ s"$input.toUnscaledLong()"
+ } else {
+ // reduce the scale to fit in a long
+ val p = Decimal.MAX_LONG_DIGITS
+ val s = p - (dt.precision - dt.scale)
+ s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
+ }
+ (Long.MinValue, prefix)
+ case dt: DecimalType =>
+ (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+ s"$DoublePrefixCmp.computePrefix($input.toDouble())")
case _ => (0L, "0L")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 81267dc915..ea1fd23d0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -107,7 +107,10 @@ object RandomDataGenerator {
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
case DecimalType.Fixed(precision, scale) => Some(
- () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new MathContext(precision)))
+ () => BigDecimal.apply(
+ rand.nextLong() % math.pow(10, precision).toLong,
+ scale,
+ new MathContext(precision)))
case DoubleType => randomNumeric[Double](
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index 0ee9ddac81..417df006ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -34,8 +34,9 @@ object DataTypeTestUtils {
* decimal types.
*/
val fractionalTypes: Set[FractionalType] = Set(
+ DecimalType.USER_DEFAULT,
+ DecimalType(20, 5),
DecimalType.SYSTEM_DEFAULT,
- DecimalType(2, 1),
DoubleType,
FloatType
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 676656518f..2e870ec8ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -36,16 +36,16 @@ object SortPrefixUtils {
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
- case StringType if sortOrder.isAscending => PrefixComparators.STRING
- case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
- case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
- if sortOrder.isAscending =>
- PrefixComparators.LONG
- case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
- if !sortOrder.isAscending =>
- PrefixComparators.LONG_DESC
- case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
- case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
+ case StringType =>
+ if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
+ if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+ case FloatType | DoubleType =>
+ if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
+ case dt: DecimalType =>
+ if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
case _ => NoOpPrefixComparator
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index b3f821e0cd..c794984851 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -61,8 +61,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
// Test sorting on different data types
for (
- dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
- if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+ dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
nullable <- Seq(true, false);
sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)