diff options
author | Kevin Yu <qyu@us.ibm.com> | 2016-05-20 12:41:14 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-05-20 12:41:14 +0800 |
commit | 17591d90e6873f30a042112f56a1686726ccbd60 (patch) | |
tree | d155359ab9626077375c3531500475a294846416 /sql | |
parent | d5c47f8ff8c09ff017e896835db044661ee60909 (diff) | |
download | spark-17591d90e6873f30a042112f56a1686726ccbd60.tar.gz spark-17591d90e6873f30a042112f56a1686726ccbd60.tar.bz2 spark-17591d90e6873f30a042112f56a1686726ccbd60.zip |
[SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections
Hello : Can you help check this PR? I am adding support for the java.math.BigInteger for java bean code path. I saw internally spark is converting the BigInteger to BigDecimal in ColumnType.scala and CatalystRowConverter.scala. I use the similar way and convert the BigInteger to the BigDecimal. .
Author: Kevin Yu <qyu@us.ibm.com>
Closes #10125 from kevinyu98/working_on_spark-11827.
Diffstat (limited to 'sql')
8 files changed, 76 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9bfc381639..9cc7b2ac79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} +import java.math.{BigInteger => JavaBigInteger} import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -326,6 +327,7 @@ object CatalystTypeConverters { val decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) + case d: JavaBigInteger => Decimal(d) case d: Decimal => d } if (decimal.changePrecision(dataType.precision, dataType.scale)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 690758205e..1fe143494a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -89,6 +89,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) + case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c0fa220d34..58df651da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -259,6 +259,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[BigDecimal] => Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[java.math.BigInteger] => + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + + case t if t <:< localTypeOf[scala.math.BigInt] => + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -592,6 +598,20 @@ object ScalaReflection extends ScalaReflection { "apply", inputObject :: Nil) + case t if t <:< localTypeOf[java.math.BigInteger] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[scala.math.BigInt] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) case t if t <:< localTypeOf[java.lang.Long] => @@ -736,6 +756,10 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[java.math.BigInteger] => + Schema(DecimalType.BigIntDecimal, nullable = true) + case t if t <:< localTypeOf[scala.math.BigInt] => + Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 2f7422b742..b907f62802 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} +import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi @@ -129,6 +129,23 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** + * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + */ + def set(bigintval: BigInteger): Decimal = { + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } + catch { + case e: ArithmeticException => + throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal") + } + } + + /** * Set this Decimal to the given Decimal value. */ def set(decimal: Decimal): Decimal = { @@ -155,6 +172,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def toScalaBigInt: BigInt = BigInt(toLong) + + def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) + def toUnscaledLong: Long = { if (decimalVal.ne(null)) { decimalVal.underlying().unscaledValue().longValue() @@ -371,6 +392,10 @@ object Decimal { def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value) + + def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger) + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) @@ -387,6 +412,8 @@ object Decimal { value match { case j: java.math.BigDecimal => apply(j) case d: BigDecimal => apply(d) + case k: scala.math.BigInt => apply(k) + case l: java.math.BigInteger => apply(l) case d: Decimal => d } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 9c1319c1c5..6b7e3714e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { private[sql] val LongDecimal = DecimalType(20, 0) private[sql] val FloatDecimal = DecimalType(14, 7) private[sql] val DoubleDecimal = DecimalType(30, 15) + private[sql] val BigIntDecimal = DecimalType(38, 0) private[sql] def forType(dataType: DataType): DecimalType = dataType match { case ByteType => ByteDecimal diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 227e835e7e..d4387890b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.encoders +import java.math.BigInteger import java.sql.{Date, Timestamp} import java.util.Arrays @@ -109,7 +110,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") - + encodeDecodeTest(BigInt("23134123123"), "scala biginteger") + encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger") encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") encodeDecodeTest("hello", "string") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 324ebbae38..35a9f44fec 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -21,6 +21,8 @@ import java.io.Serializable; import java.net.URISyntaxException; import java.net.URL; import java.util.*; +import java.math.BigInteger; +import java.math.BigDecimal; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -130,6 +132,7 @@ public class JavaDataFrameSuite { private Integer[] b = { 0, 1 }; private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List<String> d = Arrays.asList("floppy", "disk"); + private BigInteger e = new BigInteger("1234567"); public double getA() { return a; @@ -146,6 +149,8 @@ public class JavaDataFrameSuite { public List<String> getD() { return d; } + + public BigInteger getE() { return e; } } void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) { @@ -163,7 +168,9 @@ public class JavaDataFrameSuite { Assert.assertEquals( new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), schema.apply("d")); - Row first = df.select("a", "b", "c", "d").first(); + Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), + schema.apply("e")); + Row first = df.select("a", "b", "c", "d", "e").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -182,6 +189,8 @@ public class JavaDataFrameSuite { for (int i = 0; i < d.length(); i++) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } + // Java.math.BigInteger is equavient to Spark Decimal(38,0) + Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 491bdb3ef9..c9bd05d0e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -34,7 +34,9 @@ case class ReflectData( decimalField: java.math.BigDecimal, date: Date, timestampField: Timestamp, - seqInt: Seq[Int]) + seqInt: Seq[Int], + javaBigInt: java.math.BigInteger, + scalaBigInt: scala.math.BigInt) case class NullReflectData( intField: java.lang.Integer, @@ -77,13 +79,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3), + new java.math.BigInteger("1"), scala.math.BigInt(1)) Seq(data).toDF().createOrReplaceTempView("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1, 2, 3))) + new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1), + new java.math.BigDecimal(1))) } test("query case class RDD with nulls") { |