author    Daoyuan Wang <daoyuan.wang@intel.com>    2014-10-29 12:10:58 -0700
committer Michael Armbrust <michael@databricks.com>    2014-10-29 12:10:58 -0700
commit    353546766384b1e80fc8cc75c532d8d1821012b4 (patch)
tree      233857976234633947dee2ddb08ab9252621de55 /sql
parent    dff015533dd7b01b5e392f1ac5f3837e0a65f3f4 (diff)
[SPARK-4003] [SQL] add 3 types for java SQL context
In JavaSQLContext, we need to let Java programs use big decimal, timestamp, and date types.

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #2850 from adrian-wang/javacontext and squashes the following commits:

4c4292c [Daoyuan Wang] change underlying type of JavaSchemaRDD as scala
bb0508f [Daoyuan Wang] add test cases
3c58b0d [Daoyuan Wang] add 3 types for java SQL context
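In practice this means a JavaBean with BigDecimal, Date, and Timestamp fields can now be passed through applySchema. A minimal sketch, assuming the javaCtx (JavaSparkContext) and javaSqlCtx (JavaSQLContext) handles set up in JavaSQLSuite; the bean class here is a hypothetical reduction of the suite's AllTypesBean:

    import scala.beans.BeanProperty

    // Hypothetical bean limited to the three field types this commit adds.
    class DecimalDateBean extends Serializable {
      @BeanProperty var bigDecimalField: java.math.BigDecimal = _
      @BeanProperty var dateField: java.sql.Date = _
      @BeanProperty var timestampField: java.sql.Timestamp = _
    }

    val bean = new DecimalDateBean
    bean.setBigDecimalField(new java.math.BigDecimal("9.99"))
    bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
    bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))

    // Schema inference now recognizes all three property types.
    val rdd = javaCtx.parallelize(bean :: Nil)
    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[DecimalDateBean])
    schemaRDD.registerTempTable("sample")
    javaSqlCtx.sql("SELECT bigDecimalField, dateField, timestampField FROM sample")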
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala        | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala          | 41
3 files changed, 59 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index f8171c3be3..082ae03eef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
import org.apache.spark.sql.parquet.ParquetRelation
@@ -97,7 +98,9 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
iter.map { row =>
- new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
+ new GenericRow(
+ extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
+ ): ScalaRow
}
}
new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
@@ -226,6 +229,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
(org.apache.spark.sql.FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] =>
(org.apache.spark.sql.BooleanType, true)
+ case c: Class[_] if c == classOf[java.math.BigDecimal] =>
+ (org.apache.spark.sql.DecimalType, true)
+ case c: Class[_] if c == classOf[java.sql.Date] =>
+ (org.apache.spark.sql.DateType, true)
+ case c: Class[_] if c == classOf[java.sql.Timestamp] =>
+ (org.apache.spark.sql.TimestampType, true)
}
AttributeReference(property.getName, dataType, nullable)()
}
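The second hunk above extends getSchema's reflective class-to-Catalyst-type match with three new cases. A standalone condensation of that mapping (the function name is illustrative; the real match is private to JavaSQLContext and also covers the primitive wrapper classes):

    // Each bean property class maps to a (DataType, nullable) pair; the new
    // cases all mark the column nullable, matching the boxed types before them.
    def catalystTypeFor(c: Class[_]): (org.apache.spark.sql.DataType, Boolean) = c match {
      case _ if c == classOf[java.math.BigDecimal] => (org.apache.spark.sql.DecimalType, true)
      case _ if c == classOf[java.sql.Date]        => (org.apache.spark.sql.DateType, true)
      case _ if c == classOf[java.sql.Timestamp]   => (org.apache.spark.sql.TimestampType, true)
      case _ => sys.error("unsupported bean property type: " + c)
    }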
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index e44cb08309..609f7db562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -110,4 +110,16 @@ protected[sql] object DataTypeConversions {
case structType: org.apache.spark.sql.api.java.StructType =>
StructType(structType.getFields.map(asScalaStructField))
}
+
+ /** Converts Java objects to catalyst rows / types */
+ def convertJavaToCatalyst(a: Any): Any = a match {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case other => other
+ }
+
+ /** Converts catalyst objects to Java rows / types */
+ def convertCatalystToJava(a: Any): Any = a match {
+ case d: scala.math.BigDecimal => d.underlying()
+ case other => other
+ }
}
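The two helpers are inverses for decimals: convertJavaToCatalyst wraps a java.math.BigDecimal into a scala.math.BigDecimal, and convertCatalystToJava unwraps it again via underlying(). A quick round-trip sketch (note the object is protected[sql], so this only compiles from inside the org.apache.spark.sql package):

    import org.apache.spark.sql.types.util.DataTypeConversions

    val javaDec = new java.math.BigDecimal("3.14")
    // Java -> Catalyst: java.math.BigDecimal is wrapped as scala.math.BigDecimal
    val catalystDec = DataTypeConversions.convertJavaToCatalyst(javaDec)
    assert(catalystDec == scala.math.BigDecimal("3.14"))
    // Catalyst -> Java: underlying() recovers the original java.math.BigDecimal
    assert(DataTypeConversions.convertCatalystToJava(catalystDec) == javaDec)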
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index 203ff847e9..d83f3e23a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -45,6 +45,9 @@ class AllTypesBean extends Serializable {
@BeanProperty var shortField: java.lang.Short = _
@BeanProperty var byteField: java.lang.Byte = _
@BeanProperty var booleanField: java.lang.Boolean = _
+ @BeanProperty var dateField: java.sql.Date = _
+ @BeanProperty var timestampField: java.sql.Timestamp = _
+ @BeanProperty var bigDecimalField: java.math.BigDecimal = _
}
class JavaSQLSuite extends FunSuite {
@@ -73,6 +76,9 @@ class JavaSQLSuite extends FunSuite {
bean.setShortField(0.toShort)
bean.setByteField(0.toByte)
bean.setBooleanField(false)
+ bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+ bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+ bean.setBigDecimalField(new java.math.BigDecimal(0))
val rdd = javaCtx.parallelize(bean :: Nil)
val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -82,10 +88,34 @@ class JavaSQLSuite extends FunSuite {
javaSqlCtx.sql(
"""
|SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
- | booleanField
+ | booleanField, dateField, timestampField, bigDecimalField
|FROM allTypes
""".stripMargin).collect.head.row ===
- Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+ Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false, java.sql.Date.valueOf("2014-10-10"),
+ java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"), scala.math.BigDecimal(0)))
+ }
+
+ test("decimal types in JavaBeans") {
+ val bean = new AllTypesBean
+ bean.setStringField("")
+ bean.setIntField(0)
+ bean.setLongField(0)
+ bean.setFloatField(0.0F)
+ bean.setDoubleField(0.0)
+ bean.setShortField(0.toShort)
+ bean.setByteField(0.toByte)
+ bean.setBooleanField(false)
+ bean.setDateField(java.sql.Date.valueOf("2014-10-10"))
+ bean.setTimestampField(java.sql.Timestamp.valueOf("2014-10-10 00:00:00.0"))
+ bean.setBigDecimalField(new java.math.BigDecimal(0))
+
+ val rdd = javaCtx.parallelize(bean :: Nil)
+ val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+ schemaRDD.registerTempTable("decimalTypes")
+
+ assert(javaSqlCtx.sql(
+ "select bigDecimalField + bigDecimalField from decimalTypes"
+ ).collect.head.row === Seq(scala.math.BigDecimal(0)))
}
test("all types null in JavaBeans") {
@@ -98,6 +128,9 @@ class JavaSQLSuite extends FunSuite {
bean.setShortField(null)
bean.setByteField(null)
bean.setBooleanField(null)
+ bean.setDateField(null)
+ bean.setTimestampField(null)
+ bean.setBigDecimalField(null)
val rdd = javaCtx.parallelize(bean :: Nil)
val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
@@ -107,10 +140,10 @@ class JavaSQLSuite extends FunSuite {
javaSqlCtx.sql(
"""
|SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
- | booleanField
+ | booleanField, dateField, timestampField, bigDecimalField
|FROM allTypes
""".stripMargin).collect.head.row ===
- Seq.fill(8)(null))
+ Seq.fill(11)(null))
}
test("loads JSON datasets") {