aboutsummaryrefslogblamecommitdiff
path: root/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
blob: 93c231e30b49be4bb636de5a1752eff09d3e72a4 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
















                                                                           
                                  
 
                                        
 
                                     
                                           
 
                                                                   





                                                                                            
 
                             











                                                                          
                                                                                     








                                                                            








                                                                           




























































                                                                                              









































                                                                                        
 




















                                                                                                  




                                                                           
 
                                                                                       







                                                                                              





                                                                                                 



         





                                                                                     
 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.types

import org.scalatest.PrivateMethodTester

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal._

/**
 * Unit tests for [[Decimal]]: construction and precision/scale tracking, conversion to
 * double/long, compact (unscaled long) storage, hashing/equality, arithmetic, and rounding
 * behavior when changing precision.
 */
class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
  /** Check that a Decimal has the given string representation, precision and scale */
  private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
    assert(d.toString === string)
    assert(d.precision === precision)
    assert(d.scale === scale)
  }

  test("creating decimals") {
    checkDecimal(new Decimal(), "0", 1, 0)
    checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
    checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
    checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1)
    checkDecimal(Decimal("10.030"), "10.030", 5, 3)
    checkDecimal(Decimal(10.03), "10.03", 4, 2)
    checkDecimal(Decimal(17L), "17", 20, 0)
    checkDecimal(Decimal(17), "17", 10, 0)
    checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1)
    checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2)
    checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1)
    checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0)
    checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2)
    checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
    checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
    // Values whose digits do not fit in the requested precision must be rejected.
    intercept[IllegalArgumentException](Decimal(170L, 2, 1))
    intercept[IllegalArgumentException](Decimal(170L, 2, 0))
    intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
    intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
    intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
  }

  test("creating decimals with negative scale") {
    checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
    checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
    checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
    checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
    checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
    checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
  }

  test("double and long values") {
    /** Check that a Decimal converts to the given double and long values */
    def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
      assert(d.toDouble === doubleValue)
      assert(d.toLong === longValue)
    }

    checkValues(new Decimal(), 0.0, 0L)
    checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L)
    checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L)
    checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L)
    checkValues(Decimal(10.03), 10.03, 10L)
    checkValues(Decimal(17L), 17.0, 17L)
    checkValues(Decimal(17), 17.0, 17L)
    checkValues(Decimal(17L, 2, 1), 1.7, 1L)
    checkValues(Decimal(170L, 4, 2), 1.7, 1L)
    checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong)
    checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong)
    checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong)
    checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong)
    checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue)
    checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue)
    // toLong on values far outside the Long range does not saturate; these expectations
    // pin the current behavior (the result is 0L, not Long.MaxValue/MinValue).
    checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L)
    checkValues(Decimal(Double.MinValue), Double.MinValue, 0L)
  }

  // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs.
  // Symbol(...) is used instead of a symbol literal, which is deprecated in newer Scala.
  private val decimalVal = PrivateMethod[BigDecimal](Symbol("decimalVal"))

  /** Check whether a decimal is represented compactly (passing whether we expect it to be) */
  private def checkCompact(d: Decimal, expected: Boolean): Unit = {
    val isCompact = d.invokePrivate(decimalVal()).eq(null)
    assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact")
  }

  test("small decimals represented as unscaled long") {
    checkCompact(new Decimal(), true)
    checkCompact(Decimal(BigDecimal(10.03)), false)
    checkCompact(Decimal(BigDecimal(1e20)), false)
    checkCompact(Decimal(17L), true)
    checkCompact(Decimal(17), true)
    checkCompact(Decimal(17L, 2, 1), true)
    checkCompact(Decimal(170L, 4, 2), true)
    checkCompact(Decimal(17L, 24, 1), true)
    checkCompact(Decimal(1e16.toLong), true)
    checkCompact(Decimal(1e17.toLong), true)
    // Per the assertions below, the compact form is used strictly inside (-1e18, 1e18).
    checkCompact(Decimal(1e18.toLong - 1), true)
    checkCompact(Decimal(- 1e18.toLong + 1), true)
    checkCompact(Decimal(1e18.toLong - 1, 30, 10), true)
    checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true)
    checkCompact(Decimal(1e18.toLong), false)
    checkCompact(Decimal(-1e18.toLong), false)
    checkCompact(Decimal(1e18.toLong, 30, 10), false)
    checkCompact(Decimal(-1e18.toLong, 30, 10), false)
    checkCompact(Decimal(Long.MaxValue), false)
    checkCompact(Decimal(Long.MinValue), false)
  }

  test("hash code") {
    assert(Decimal(123).hashCode() === (123).##)
    assert(Decimal(-123).hashCode() === (-123).##)
    assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##)
    assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##)
    assert(Decimal(BigDecimal(123)).hashCode() === (123).##)

    val reallyBig = BigDecimal("123182312312313232112312312123.1231231231")
    assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode)
  }

  test("equals") {
    // The decimals on the left are stored compactly, while the ones on the right aren't
    checkCompact(Decimal(123), true)
    checkCompact(Decimal(BigDecimal(123)), false)
    checkCompact(Decimal("123"), false)
    assert(Decimal(123) === Decimal(BigDecimal(123)))
    assert(Decimal(123) === Decimal(BigDecimal("123.00")))
    assert(Decimal(-123) === Decimal(BigDecimal(-123)))
    assert(Decimal(-123) === Decimal(BigDecimal("-123.00")))
  }

  test("isZero") {
    assert(Decimal(0).isZero)
    assert(Decimal(0, 4, 2).isZero)
    assert(Decimal("0").isZero)
    assert(Decimal("0.000").isZero)
    assert(!Decimal(1).isZero)
    assert(!Decimal(1, 4, 2).isZero)
    assert(!Decimal("1").isZero)
    assert(!Decimal("0.001").isZero)
  }

  test("arithmetic") {
    assert(Decimal(100) + Decimal(-100) === Decimal(0))
    assert(Decimal(100) - Decimal(-100) === Decimal(200))
    assert(Decimal(100) * Decimal(-100) === Decimal(-10000))
    assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26))
    assert(Decimal(100) / Decimal(-100) === Decimal(-1))
    // Division and remainder by zero yield null rather than throwing.
    assert(Decimal(100) / Decimal(0) === null)
    assert(Decimal(100) % Decimal(-100) === Decimal(0))
    assert(Decimal(100) % Decimal(3) === Decimal(1))
    assert(Decimal(-100) % Decimal(3) === Decimal(-1))
    assert(Decimal(100) % Decimal(0) === null)
  }

  // regression test for SPARK-8359
  test("accurate precision after multiplication") {
    val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal
    assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249")
  }

  // regression test for SPARK-8677
  test("fix non-terminating decimal expansion problem") {
    val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
    // The difference between decimal should not be more than 0.001 (in either direction).
    assert(scala.math.abs(decimal.toDouble - 0.333) < 0.001)
  }

  // regression test for SPARK-8800
  test("fix loss of precision/scale when doing division operation") {
    val a = Decimal(2) / Decimal(3)
    assert(a.toDouble < 1.0 && a.toDouble > 0.6)
    val b = Decimal(1) / Decimal(8)
    assert(b.toDouble === 0.125)
  }

  // Despite its name, this test only exercises `set` and the (unscaled, precision, scale)
  // factory; setOrNull is not called here.
  test("set/setOrNull") {
    assert(new Decimal().set(10L, 10, 0).toUnscaledLong === 10L)
    assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
    assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
  }

  test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
      Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
        Seq("", "-").foreach { sign =>
          val bd = BigDecimal(sign + n)
          // Every input has one fractional digit, so bd * 10 is exact and the value can be
          // built as a compact Decimal with scale 1; scala.math.BigDecimal is the reference.
          val unscaled = (bd * 10).toLongExact
          val d = Decimal(unscaled, 8, 1)
          assert(d.changePrecision(10, 0, mode))
          assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")

          // toPrecision must return a distinct copy equal to the in-place result.
          val copy = d.toPrecision(10, 0, mode).orNull
          assert(copy !== null)
          assert(d.ne(copy))
          assert(d === copy)
          assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
        }
      }
    }
  }

  test("SPARK-20341: support BigInt's value does not fit in long value range") {
    val bigInt = scala.math.BigInt("9223372036854775808")
    val decimal = Decimal.apply(bigInt)
    assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808")
  }
}