author     Reynold Xin <rxin@databricks.com>  2015-01-18 11:01:42 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-18 11:01:42 -0800
commit     1727e0841cf9948e601ae2936fe89094c8c0c835 (patch)
tree       6017e327a6ca2b7a42c8a102d1c6c3a889aa8628 /sql
parent     ad16da1bcc500d0fe594853cd00470dc34b007fa (diff)
[SPARK-5279][SQL] Use java.math.BigDecimal as the exposed Decimal type.
Author: Reynold Xin <rxin@databricks.com>

Closes #4092 from rxin/bigdecimal and squashes the following commits:

27b08c9 [Reynold Xin] Fixed test.
10cb496 [Reynold Xin] [SPARK-5279][SQL] Use java.math.BigDecimal as the exposed Decimal type.
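The change is only about the type handed back to callers: Spark SQL keeps its internal Decimal class, but values surfaced through the public Row API are now java.math.BigDecimal rather than scala.math.BigDecimal. A minimal sketch of the user-facing effect, assuming only that spark-catalyst with this patch is on the classpath (the object name is made up for illustration):

import org.apache.spark.sql.types.Decimal

object ExposedDecimalSketch {
  def main(args: Array[String]): Unit = {
    // The new apply added in this patch accepts a java.math.BigDecimal directly.
    val internal: Decimal = Decimal(new java.math.BigDecimal("123.45"))

    // After this patch the value exposed to users (e.g. via Row.getDecimal) is a
    // java.math.BigDecimal, produced here through toJavaBigDecimal.
    val exposed: java.math.BigDecimal = internal.toJavaBigDecimal
    println(exposed) // 123.45
  }
}

Both Decimal(java.math.BigDecimal) and toJavaBigDecimal appear in the hunks below.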
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 6
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 4
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 8
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 38
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala | 8
-rw-r--r--  sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala | 2
-rw-r--r--  sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala | 2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 5
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala | 4
-rw-r--r--  sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala | 2
-rw-r--r--  sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala | 2
27 files changed, 101 insertions, 77 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index a28a1e90dd..208ec92987 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -116,7 +116,7 @@ trait Row extends Seq[Any] with Serializable {
* FloatType -> java.lang.Float
* DoubleType -> java.lang.Double
* StringType -> String
- * DecimalType -> scala.math.BigDecimal
+ * DecimalType -> java.math.BigDecimal
*
* DateType -> java.sql.Date
* TimestampType -> java.sql.Timestamp
@@ -141,7 +141,7 @@ trait Row extends Seq[Any] with Serializable {
* FloatType -> java.lang.Float
* DoubleType -> java.lang.Double
* StringType -> String
- * DecimalType -> scala.math.BigDecimal
+ * DecimalType -> java.math.BigDecimal
*
* DateType -> java.sql.Date
* TimestampType -> java.sql.Timestamp
@@ -227,7 +227,7 @@ trait Row extends Seq[Any] with Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getDecimal(i: Int): BigDecimal = apply(i).asInstanceOf[BigDecimal]
+ def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal]
/**
* Returns the value at position i of date type as java.sql.Date.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 697bacfedc..d280db83b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -66,6 +66,7 @@ trait ScalaReflection {
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
+ case (d: java.math.BigDecimal, _) => Decimal(d)
case (other, _) => other
}
@@ -78,7 +79,7 @@ trait ScalaReflection {
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (r: Row, s: StructType) => convertRowToScala(r, s)
- case (d: Decimal, _: DecimalType) => d.toBigDecimal
+ case (d: Decimal, _: DecimalType) => d.toJavaBigDecimal
case (other, _) => other
}
@@ -152,6 +153,7 @@ trait ScalaReflection {
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
@@ -182,7 +184,7 @@ trait ScalaReflection {
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
case obj: DateType.JvmType => DateType
- case obj: BigDecimal => DecimalType.Unlimited
+ case obj: java.math.BigDecimal => DecimalType.Unlimited
case obj: Decimal => DecimalType.Unlimited
case obj: TimestampType.JvmType => TimestampType
case null => NullType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d19563e95c..0b36d8b9bf 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -343,13 +343,13 @@ class SqlParser extends AbstractSparkSQLParser {
| floatLit ^^ { f => Literal(f.toDouble) }
)
- private def toNarrowestIntegerType(value: String) = {
+ private def toNarrowestIntegerType(value: String): Any = {
val bigIntValue = BigDecimal(value)
bigIntValue match {
case v if bigIntValue.isValidInt => v.toIntExact
case v if bigIntValue.isValidLong => v.toLongExact
- case v => v
+ case v => v.underlying()
}
}
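For context, the narrowing behaviour above can be exercised in isolation. The following is a hypothetical standalone re-implementation of the same logic (not Spark's SqlParser), showing that small literals come back as Int or Long and anything larger now falls back to java.math.BigDecimal via underlying():

object NarrowestIntegerTypeSketch {
  private def toNarrowestIntegerType(value: String): Any = {
    val bigIntValue = BigDecimal(value)
    bigIntValue match {
      case v if v.isValidInt  => v.toIntExact    // fits in an Int
      case v if v.isValidLong => v.toLongExact   // fits in a Long
      case v                  => v.underlying()  // falls back to java.math.BigDecimal
    }
  }

  def main(args: Array[String]): Unit = {
    println(toNarrowestIntegerType("1").getClass.getName)                    // java.lang.Integer
    println(toNarrowestIntegerType("9223372036854775807").getClass.getName)  // java.lang.Long
    println(toNarrowestIntegerType("9223372036854775808").getClass.getName)  // java.math.BigDecimal
  }
}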
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 15353361d9..6ef8577fd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -403,8 +403,8 @@ trait HiveTypeCoercion {
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
*/
object BooleanComparisons extends Rule[LogicalPlan] {
- val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, BigDecimal(1)).map(Literal(_))
- val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, BigDecimal(0)).map(Literal(_))
+ val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
+ val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 8bc36a238d..26c855878d 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -126,7 +126,8 @@ package object dsl {
implicit def doubleToLiteral(d: Double): Literal = Literal(d)
implicit def stringToLiteral(s: String): Literal = Literal(s)
implicit def dateToLiteral(d: Date): Literal = Literal(d)
- implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d)
+ implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
+ implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index c94a947fb2..5b389aad7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -32,6 +32,7 @@ object Literal {
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
+ case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
case d: Decimal => Literal(d, DecimalType.Unlimited)
case t: Timestamp => Literal(t, TimestampType)
case d: Date => Literal(d, DateType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
index 08bb933a2b..21f478c80c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeConversions.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.types
import java.text.SimpleDateFormat
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.Decimal
protected[sql] object DataTypeConversions {
@@ -56,13 +55,7 @@ protected[sql] object DataTypeConversions {
/** Converts Java objects to catalyst rows / types */
def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type
- case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d))
+ case (d: java.math.BigDecimal, _) => Decimal(d)
case (other, _) => other
}
-
- /** Converts Java objects to catalyst rows / types */
- def convertCatalystToJava(a: Any): Any = a match {
- case d: scala.math.BigDecimal => d.underlying()
- case other => other
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index a85c4316e1..21cc6cea4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -143,7 +143,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.bigDecimal
+ def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying()
def toUnscaledLong: Long = {
if (decimalVal.ne(null)) {
@@ -298,6 +298,8 @@ object Decimal {
def apply(value: BigDecimal): Decimal = new Decimal().set(value)
+ def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)
+
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
index e1cbe6650a..bcd74603d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
@@ -501,7 +501,7 @@ case class PrecisionInfo(precision: Int, scale: Int)
/**
* :: DeveloperApi ::
*
- * The data type representing `scala.math.BigDecimal` values.
+ * The data type representing `java.math.BigDecimal` values.
* A Decimal that might have fixed precision and scale, or unlimited values for these.
*
* Please use [[DataTypes.createDecimalType()]] to create a specific instance.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 117725df32..6df5db4c80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -43,7 +43,7 @@ case class NullableData(
byteField: java.lang.Byte,
booleanField: java.lang.Boolean,
stringField: String,
- decimalField: BigDecimal,
+ decimalField: java.math.BigDecimal,
dateField: Date,
timestampField: Timestamp,
binaryField: Array[Byte])
@@ -204,7 +204,8 @@ class ScalaReflectionSuite extends FunSuite {
assert(DoubleType === typeOfObject(1.7976931348623157E308))
// DecimalType
- assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318")))
+ assert(DecimalType.Unlimited ===
+ typeOfObject(new java.math.BigDecimal("1.7976931348623157E318")))
// DateType
assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 7a0249137a..30564e14fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -46,7 +46,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
- kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
+ kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer)
+ kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer)
// Specific hashsets must come first TODO: Move to core.
kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer)
@@ -99,14 +100,25 @@ private[sql] object SparkSqlSerializer {
}
}
-private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
- def write(kryo: Kryo, output: Output, bd: math.BigDecimal) {
+private[sql] class JavaBigDecimalSerializer extends Serializer[java.math.BigDecimal] {
+ def write(kryo: Kryo, output: Output, bd: java.math.BigDecimal) {
// TODO: There are probably more efficient representations than strings...
- output.writeString(bd.toString())
+ output.writeString(bd.toString)
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[java.math.BigDecimal]): java.math.BigDecimal = {
+ new java.math.BigDecimal(input.readString())
+ }
+}
+
+private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] {
+ def write(kryo: Kryo, output: Output, bd: BigDecimal) {
+ // TODO: There are probably more efficient representations than strings...
+ output.writeString(bd.toString)
}
def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = {
- BigDecimal(input.readString())
+ new java.math.BigDecimal(input.readString())
}
}
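Both serializers above fall back to the string form of the number. A plain-JDK illustration of why that round trip is lossless (this is not the Kryo path itself, just the underlying idea):

object BigDecimalStringRoundTrip {
  def main(args: Array[String]): Unit = {
    val original = new java.math.BigDecimal("92233720368547758070.1200")
    val written  = original.toString                  // what write() puts on the wire
    val restored = new java.math.BigDecimal(written)  // what read() reconstructs
    // Both the numeric value and the scale survive the round trip.
    assert(restored == original && restored.scale == original.scale)
    println(restored) // 92233720368547758070.1200
  }
}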
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 741ccb8fb8..7ed64aad10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,9 +135,7 @@ object EvaluatePython {
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
- case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
-
- // Pyrolite can handle Timestamp
+ // Pyrolite can handle Timestamp and Decimal
case (other, _) => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 453b560ff8..db70a7eac7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -18,11 +18,10 @@
package org.apache.spark.sql.json
import java.io.StringWriter
+import java.sql.{Date, Timestamp}
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
-import scala.math.BigDecimal
-import java.sql.{Date, Timestamp}
import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.core.JsonFactory
@@ -333,9 +332,9 @@ private[sql] object JsonRDD extends Logging {
value match {
case value: java.lang.Integer => Decimal(value)
case value: java.lang.Long => Decimal(value)
- case value: java.math.BigInteger => Decimal(BigDecimal(value))
+ case value: java.math.BigInteger => Decimal(new java.math.BigDecimal(value))
case value: java.lang.Double => Decimal(value)
- case value: java.math.BigDecimal => Decimal(BigDecimal(value))
+ case value: java.math.BigDecimal => Decimal(value)
}
}
@@ -446,7 +445,6 @@ private[sql] object JsonRDD extends Logging {
case (FloatType, v: Float) => gen.writeNumber(v)
case (DoubleType, v: Double) => gen.writeNumber(v)
case (LongType, v: Long) => gen.writeNumber(v)
- case (DecimalType(), v: scala.math.BigDecimal) => gen.writeNumber(v.bigDecimal)
case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 86d21f49fe..9e96738ac0 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable {
List<Row> expectedResult = new ArrayList<Row>(2);
expectedResult.add(
RowFactory.create(
- scala.math.BigDecimal$.MODULE$.apply("92233720368547758070"),
+ new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -139,7 +139,7 @@ public class JavaApplySchemaSuite implements Serializable {
"this is a simple string."));
expectedResult.add(
RowFactory.create(
- scala.math.BigDecimal$.MODULE$.apply("92233720368547758069"),
+ new java.math.BigDecimal("92233720368547758069"),
false,
1.7976931348623157E305,
11,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index efe622f8bc..2bcfe28456 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -207,17 +207,17 @@ class DslQuerySuite extends QueryTest {
checkAnswer(
decimalData.aggregate(avg('a)),
- BigDecimal(2.0))
+ new java.math.BigDecimal(2.0))
checkAnswer(
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
- (BigDecimal(2.0), BigDecimal(6)) :: Nil)
+ (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
- BigDecimal(2.0))
+ new java.math.BigDecimal(2.0))
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
- (BigDecimal(2.0), BigDecimal(6)) :: Nil)
+ (new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
}
test("null average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3d9f0cbf80..68ddecc7f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -57,7 +57,19 @@ class QueryTest extends PlanTest {
}
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
- def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
+ def prepareAnswer(answer: Seq[Any]): Seq[Any] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ val converted = answer.map {
+ case s: Seq[_] => s.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case o => o
+ }
+ case o => o
+ }
+ if (!isSorted) converted.sortBy(_.toString) else converted
+ }
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
fail(
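The conversion in prepareAnswer matters because java.math.BigDecimal#equals compares scale as well as value, while scala.math.BigDecimal equality is numeric (compareTo-style). A small standalone illustration of the difference:

object DecimalEqualitySketch {
  def main(args: Array[String]): Unit = {
    val a = new java.math.BigDecimal("2.0")
    val b = new java.math.BigDecimal("2.00")

    println(a.equals(b))          // false: same value, different scale
    println(a.compareTo(b) == 0)  // true: numerically equal

    // Wrapping in scala.math.BigDecimal (as prepareAnswer does) gives
    // scale-insensitive equality, so expected and actual rows compare cleanly.
    println(BigDecimal(a) == BigDecimal(b))  // true
  }
}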
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6c95bad697..54fabc5c91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -844,11 +844,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
)
checkAnswer(
- sql("SELECT 9223372036854775808"), BigDecimal("9223372036854775808")
+ sql("SELECT 9223372036854775808"), new java.math.BigDecimal("9223372036854775808")
)
checkAnswer(
- sql("SELECT -9223372036854775809"), BigDecimal("-9223372036854775809")
+ sql("SELECT -9223372036854775809"), new java.math.BigDecimal("-9223372036854775809")
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 40fb8d5779..ee381da491 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -33,7 +33,7 @@ case class ReflectData(
shortField: Short,
byteField: Byte,
booleanField: Boolean,
- decimalField: BigDecimal,
+ decimalField: java.math.BigDecimal,
date: Date,
timestampField: Timestamp,
seqInt: Seq[Int])
@@ -77,13 +77,13 @@ case class ComplexReflectData(
class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
- BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
+ new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
- BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
+ new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
test("query case class RDD with nulls") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 497897c3c0..808ed5288c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -59,6 +59,7 @@ object TestData {
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
+
val decimalData =
TestSQLContext.sparkContext.parallelize(
DecimalData(1, 1) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 1dd85a3bb4..2bc9aede32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -229,7 +229,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -283,7 +283,8 @@ class JsonSuite extends QueryTest {
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from jsonTable"),
- (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil
+ (new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
)
// Access elements of an array of arrays.
@@ -318,9 +319,9 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select struct, struct.field1, struct.field2 from jsonTable"),
Row(
- Row(true, BigDecimal("92233720368547758070")),
+ Row(true, new java.math.BigDecimal("92233720368547758070")),
true,
- BigDecimal("92233720368547758070")) :: Nil
+ new java.math.BigDecimal("92233720368547758070")) :: Nil
)
// Access an array field of a struct.
@@ -372,9 +373,9 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
("true", 11L, null, 1.1, "13.1", "str1") ::
- ("12", null, BigDecimal("21474836470.9"), null, null, "true") ::
- ("false", 21474836470L, BigDecimal("92233720368547758070"), 100, "str1", "false") ::
- (null, 21474836570L, BigDecimal(1.1), 21474836470L, "92233720368547758070", null) :: Nil
+ ("12", null, new java.math.BigDecimal("21474836470.9"), null, null, "true") ::
+ ("false", 21474836470L, new java.math.BigDecimal("92233720368547758070"), 100, "str1", "false") ::
+ (null, 21474836570L, new java.math.BigDecimal("1.1"), 21474836470L, "92233720368547758070", null) :: Nil
)
// Number and Boolean conflict: resolve the type as number in this query.
@@ -397,7 +398,7 @@ class JsonSuite extends QueryTest {
// Widening to DecimalType
checkAnswer(
sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"),
- Seq(BigDecimal("21474836472.1")) :: Seq(BigDecimal("92233720368547758071.2")) :: Nil
+ Seq(new java.math.BigDecimal("21474836472.1")) :: Seq(new java.math.BigDecimal("92233720368547758071.2")) :: Nil
)
// Widening to DoubleType
@@ -415,7 +416,7 @@ class JsonSuite extends QueryTest {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
- BigDecimal("92233720368547758061.2").toDouble
+ new java.math.BigDecimal("92233720368547758061.2").doubleValue
)
// String and Boolean conflict: resolve the type as string.
@@ -463,7 +464,7 @@ class JsonSuite extends QueryTest {
jsonSchemaRDD.
where('num_str > BigDecimal("92233720368547758060")).
select('num_str + 1.2 as Symbol("num")),
- BigDecimal("92233720368547758061.2")
+ new java.math.BigDecimal("92233720368547758061.2")
)
// The following test will fail. The type of num_str is StringType.
@@ -567,7 +568,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -593,7 +594,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTableSQL"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -625,7 +626,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable1"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -642,7 +643,7 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select * from jsonTable2"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -850,7 +851,7 @@ class JsonSuite extends QueryTest {
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
- (BigDecimal("92233720368547758070"),
+ (new java.math.BigDecimal("92233720368547758070"),
true,
1.7976931348623157E308,
10,
@@ -876,7 +877,8 @@ class JsonSuite extends QueryTest {
// Access elements of a BigInteger array (we use DecimalType internally).
checkAnswer(
sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"),
- (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil
+ (new java.math.BigDecimal("922337203685477580700"),
+ new java.math.BigDecimal("-922337203685477580800"), null) :: Nil
)
// Access elements of an array of arrays.
@@ -901,9 +903,9 @@ class JsonSuite extends QueryTest {
checkAnswer(
sql("select struct, struct.field1, struct.field2 from complexTable"),
Row(
- Row(true, BigDecimal("92233720368547758070")),
+ Row(true, new java.math.BigDecimal("92233720368547758070")),
true,
- BigDecimal("92233720368547758070")) :: Nil
+ new java.math.BigDecimal("92233720368547758070")) :: Nil
)
// Access an array field of a struct.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 382dddcdea..264f6d94c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -70,8 +70,8 @@ case class AllDataTypesScan(
i.toLong,
i.toFloat,
i.toDouble,
- BigDecimal(i),
- BigDecimal(i),
+ new java.math.BigDecimal(i),
+ new java.math.BigDecimal(i),
new Date((i + 1) * 8640000),
new Timestamp(20000 + i),
s"varchar_$i",
@@ -99,8 +99,8 @@ class TableScanSuite extends DataSourceTest {
i.toLong,
i.toFloat,
i.toDouble,
- BigDecimal(i),
- BigDecimal(i),
+ new java.math.BigDecimal(i),
+ new java.math.BigDecimal(i),
new Date((i + 1) * 8640000),
new Timestamp(20000 + i),
s"varchar_$i",
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index 171d707b13..166c56b9df 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -122,7 +122,7 @@ private[hive] class SparkExecuteStatementOperation(
case FloatType =>
to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
case DecimalType() =>
- val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
+ val hiveDecimal = from.getDecimal(ordinal)
to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
case LongType =>
to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal)))
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index bec9d9aca3..eaf7a1ddd4 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -94,7 +94,7 @@ private[hive] class SparkExecuteStatementOperation(
case FloatType =>
to += from.getFloat(ordinal)
case DecimalType() =>
- to += from.getAs[BigDecimal](ordinal).bigDecimal
+ to += from.getDecimal(ordinal)
case LongType =>
to += from.getLong(ordinal)
case ByteType =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 4246b8b091..10833c1132 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -409,8 +409,9 @@ private object HiveContext {
case (d: Date, DateType) => new DateWritable(d).toString
case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
- case (decimal: BigDecimal, DecimalType()) => // Hive strips trailing zeros so use its toString
- HiveShim.createDecimal(decimal.underlying()).toString
+ case (decimal: java.math.BigDecimal, DecimalType()) =>
+ // Hive strips trailing zeros so use its toString
+ HiveShim.createDecimal(decimal).toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
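The comment above notes that Hive strips trailing zeros when formatting decimals. A plain-JDK sketch of the same effect, without Hive (HiveShim.createDecimal itself is unchanged by this patch):

object TrailingZerosSketch {
  def main(args: Array[String]): Unit = {
    val d = new java.math.BigDecimal("1.500")
    println(d.toPlainString)                       // 1.500
    println(d.stripTrailingZeros().toPlainString)  // 1.5, roughly what Hive's formatting produces
  }
}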
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 5140d2064c..d87c4945c8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -341,7 +341,7 @@ private[hive] trait HiveInspectors {
(o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
case _: JavaHiveDecimalObjectInspector =>
- (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())
+ (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal)
case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
@@ -412,7 +412,7 @@ private[hive] trait HiveInspectors {
case _: HiveDecimalObjectInspector if x.preferWritable() =>
HiveShim.getDecimalWritable(a.asInstanceOf[Decimal])
case _: HiveDecimalObjectInspector =>
- HiveShim.createDecimal(a.asInstanceOf[Decimal].toBigDecimal.underlying())
+ HiveShim.createDecimal(a.asInstanceOf[Decimal].toJavaBigDecimal)
case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a)
case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a)
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 58417a15bb..c0b7741bc3 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -174,7 +174,7 @@ private[hive] object HiveShim {
null
} else {
new hiveIo.HiveDecimalWritable(
- HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
+ HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal))
}
def getPrimitiveNullWritable: NullWritable = NullWritable.get()
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 1f768ca971..c04cda7bf1 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -276,7 +276,7 @@ private[hive] object HiveShim {
} else {
// TODO precise, scale?
new hiveIo.HiveDecimalWritable(
- HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
+ HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal))
}
def getPrimitiveNullWritable: NullWritable = NullWritable.get()