author     Michael Armbrust <michael@databricks.com>    2014-05-05 22:59:42 -0700
committer  Matei Zaharia <matei@databricks.com>    2014-05-05 22:59:42 -0700
commit     3c64750bdd4c2d0a5562f90aead37be81627cc9d (patch)
tree       a163bcb9f71d50f21ba3f233418faa55305ebfb0
parent     a2262cdb7aa30e9f45043f1440d4b02bc3340f9f (diff)
[SQL] SPARK-1732 - Support for null primitive values.
I also removed a println that I bumped into.

Author: Michael Armbrust <michael@databricks.com>

Closes #658 from marmbrus/nullPrimitives and squashes the following commits:

a3ec4f3 [Michael Armbrust] Remove println.
695606b [Michael Armbrust] Support for null primitives using Scala and Java reflection.
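The user-facing effect, as a minimal sketch based on the new ScalaReflectionRelationSuite tests below (the case class and table name are illustrative; `sparkContext`, `registerAsTable`, and `sql` are presumed to come from `import org.apache.spark.sql.test.TestSQLContext._`, as the test suite uses them unqualified):

    import org.apache.spark.sql.test.TestSQLContext._

    // Fields declared as Option[_] (or as boxed java.lang types) are treated as nullable columns.
    case class OptionalData(intField: Option[Int], longField: Option[Long])

    val rdd = sparkContext.parallelize(OptionalData(None, Some(1L)) :: Nil)
    rdd.registerAsTable("optionalData")   // via the implicit createSchemaRDD conversion

    // None is stored as a SQL NULL; Some(v) is unwrapped to v.
    sql("SELECT intField, longField FROM optionalData").collect()   // e.g. Array([null,1])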
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala    14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala          8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala                     2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala         3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala    34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala            61
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala          5
7 files changed, 122 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 446d0e0bd7..792ef6cee6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -44,7 +44,8 @@ object ScalaReflection {
case t if t <:< typeOf[Product] =>
val params = t.member("<init>": TermName).asMethod.paramss
StructType(
- params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+ params.head.map(p =>
+ StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => BinaryType
case t if t <:< typeOf[Array[_]] =>
@@ -58,6 +59,17 @@ object ScalaReflection {
case t if t <:< typeOf[String] => StringType
case t if t <:< typeOf[Timestamp] => TimestampType
case t if t <:< typeOf[BigDecimal] => DecimalType
+ case t if t <:< typeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ schemaFor(optType)
+ case t if t <:< typeOf[java.lang.Integer] => IntegerType
+ case t if t <:< typeOf[java.lang.Long] => LongType
+ case t if t <:< typeOf[java.lang.Double] => DoubleType
+ case t if t <:< typeOf[java.lang.Float] => FloatType
+ case t if t <:< typeOf[java.lang.Short] => ShortType
+ case t if t <:< typeOf[java.lang.Byte] => ByteType
+ case t if t <:< typeOf[java.lang.Boolean] => BooleanType
+ // TODO: The following datatypes could be marked as non-nullable.
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
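For readers less familiar with Scala runtime reflection, here is a standalone sketch of the matching logic added above (not part of the patch; it returns type names as strings rather than Catalyst DataTypes): Option[_] is unwrapped and resolved recursively through its element type, and the boxed java.lang classes map to the same SQL types as the corresponding primitives. It can be pasted into a REPL with scala-reflect on the classpath:

    import scala.reflect.runtime.universe._

    def sqlTypeNameFor(tpe: Type): String = tpe match {
      case t if t <:< typeOf[Option[_]] =>
        // Unwrap Option and recurse on the element type, mirroring the case added above.
        val TypeRef(_, _, Seq(elementType)) = t
        sqlTypeNameFor(elementType)
      case t if t <:< typeOf[java.lang.Integer] => "IntegerType"   // boxed, so nulls are representable
      case t if t <:< typeOf[java.lang.Long]    => "LongType"
      case t if t <:< definitions.IntTpe        => "IntegerType"   // primitive, as before
      case t if t <:< definitions.LongTpe       => "LongType"
      case t if t <:< typeOf[String]            => "StringType"
      case _                                    => "UnknownType"
    }

    sqlTypeNameFor(typeOf[Option[Int]])      // "IntegerType"
    sqlTypeNameFor(typeOf[java.lang.Long])   // "LongType"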
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index a734708879..57facbe10f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -132,6 +132,14 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
case c: Class[_] if c == java.lang.Float.TYPE => FloatType
case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+ case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+ case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+ case c: Class[_] if c == classOf[java.lang.Long] => LongType
+ case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+ case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+ case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+ case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
}
// TODO: Nullability could be stricter.
AttributeReference(property.getName, dataType, nullable = true)()
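For context, the Class[_] values matched above come from the JavaBean's property getters; java.beans.Introspector is the standard way to enumerate them, and appears to be how the surrounding applySchema code obtains `property`. A standalone sketch with a hypothetical NullableAgeBean:

    import java.beans.Introspector
    import scala.beans.BeanProperty

    class NullableAgeBean extends Serializable {
      @BeanProperty var name: String = _
      @BeanProperty var age: java.lang.Integer = _   // boxed, so it can hold null
    }

    // Each getter's return type is the Class[_] that the cases above match on.
    Introspector.getBeanInfo(classOf[NullableAgeBean]).getPropertyDescriptors
      .filterNot(_.getName == "class")
      .foreach(p => println(p.getName + " -> " + p.getReadMethod.getReturnType))
    // age -> class java.lang.Integer   (hits the new classOf[java.lang.Integer] case)
    // name -> class java.lang.String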
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index 362fe76958..9b0dd21761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
/**
* A result row from a SparkSQL query.
*/
-class Row(row: ScalaRow) extends Serializable {
+class Row(private[spark] val row: ScalaRow) extends Serializable {
/** Returns the number of columns present in this Row. */
def length: Int = row.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d807187a5f..8969794c69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -164,6 +164,7 @@ case class Sort(
@DeveloperApi
object ExistingRdd {
def convertToCatalyst(a: Any): Any = a match {
+ case o: Option[_] => o.orNull
case s: Seq[Any] => s.map(convertToCatalyst)
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case other => other
@@ -180,7 +181,7 @@ object ExistingRdd {
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
- mutableRow(i) = r.productElement(i)
+ mutableRow(i) = convertToCatalyst(r.productElement(i))
i += 1
}
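The new Option handling in convertToCatalyst, shown in isolation (a simplified stand-in: the Product case here builds a plain Seq, whereas the real code builds a GenericRow):

    // Options are flattened to their value or null; the rule is applied recursively
    // to sequences and products, as in the patched ExistingRdd.convertToCatalyst.
    def convertToCatalyst(a: Any): Any = a match {
      case o: Option[_] => o.orNull
      case s: Seq[_]    => s.map(convertToCatalyst)
      case p: Product   => p.productIterator.map(convertToCatalyst).toSeq
      case other        => other
    }

    convertToCatalyst(Some(42))               // 42
    convertToCatalyst(None)                   // null
    convertToCatalyst(("a", Some(1), None))   // Seq("a", 1, null)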
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 1cbf973c34..f2934da9a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -36,6 +36,24 @@ case class ReflectData(
timestampField: Timestamp,
seqInt: Seq[Int])
+case class NullReflectData(
+ intField: java.lang.Integer,
+ longField: java.lang.Long,
+ floatField: java.lang.Float,
+ doubleField: java.lang.Double,
+ shortField: java.lang.Short,
+ byteField: java.lang.Byte,
+ booleanField: java.lang.Boolean)
+
+case class OptionalReflectData(
+ intField: Option[Int],
+ longField: Option[Long],
+ floatField: Option[Float],
+ doubleField: Option[Double],
+ shortField: Option[Short],
+ byteField: Option[Byte],
+ booleanField: Option[Boolean])
+
case class ReflectBinary(data: Array[Byte])
class ScalaReflectionRelationSuite extends FunSuite {
@@ -48,6 +66,22 @@ class ScalaReflectionRelationSuite extends FunSuite {
assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
}
+ test("query case class RDD with nulls") {
+ val data = NullReflectData(null, null, null, null, null, null, null)
+ val rdd = sparkContext.parallelize(data :: Nil)
+ rdd.registerAsTable("reflectNullData")
+
+ assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null))
+ }
+
+ test("query case class RDD with Nones") {
+ val data = OptionalReflectData(None, None, None, None, None, None, None)
+ val rdd = sparkContext.parallelize(data :: Nil)
+ rdd.registerAsTable("reflectOptionalData")
+
+ assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null))
+ }
+
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index def0e046a3..9fff7222fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -35,6 +35,17 @@ class PersonBean extends Serializable {
var age: Int = _
}
+class AllTypesBean extends Serializable {
+ @BeanProperty var stringField: String = _
+ @BeanProperty var intField: java.lang.Integer = _
+ @BeanProperty var longField: java.lang.Long = _
+ @BeanProperty var floatField: java.lang.Float = _
+ @BeanProperty var doubleField: java.lang.Double = _
+ @BeanProperty var shortField: java.lang.Short = _
+ @BeanProperty var byteField: java.lang.Byte = _
+ @BeanProperty var booleanField: java.lang.Boolean = _
+}
+
class JavaSQLSuite extends FunSuite {
val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
val javaSqlCtx = new JavaSQLContext(javaCtx)
@@ -50,4 +61,54 @@ class JavaSQLSuite extends FunSuite {
schemaRDD.registerAsTable("people")
javaSqlCtx.sql("SELECT * FROM people").collect()
}
+
+ test("all types in JavaBeans") {
+ val bean = new AllTypesBean
+ bean.setStringField("")
+ bean.setIntField(0)
+ bean.setLongField(0)
+ bean.setFloatField(0.0F)
+ bean.setDoubleField(0.0)
+ bean.setShortField(0.toShort)
+ bean.setByteField(0.toByte)
+ bean.setBooleanField(false)
+
+ val rdd = javaCtx.parallelize(bean :: Nil)
+ val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+ schemaRDD.registerAsTable("allTypes")
+
+ assert(
+ javaSqlCtx.sql(
+ """
+ |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+ | booleanField
+ |FROM allTypes
+ """.stripMargin).collect.head.row ===
+ Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+ }
+
+ test("all types null in JavaBeans") {
+ val bean = new AllTypesBean
+ bean.setStringField(null)
+ bean.setIntField(null)
+ bean.setLongField(null)
+ bean.setFloatField(null)
+ bean.setDoubleField(null)
+ bean.setShortField(null)
+ bean.setByteField(null)
+ bean.setBooleanField(null)
+
+ val rdd = javaCtx.parallelize(bean :: Nil)
+ val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+ schemaRDD.registerAsTable("allTypes")
+
+ assert(
+ javaSqlCtx.sql(
+ """
+ |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+ | booleanField
+ |FROM allTypes
+ """.stripMargin).collect.head.row ===
+ Seq.fill(8)(null))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 325173cf95..71be410567 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
import org.scalatest.FunSuite
+import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
-class ColumnTypeSuite extends FunSuite {
+class ColumnTypeSuite extends FunSuite with Logging {
val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
@@ -163,7 +164,7 @@ class ColumnTypeSuite extends FunSuite {
buffer.rewind()
seq.foreach { expected =>
- println("buffer = " + buffer + ", expected = " + expected)
+ logger.info("buffer = " + buffer + ", expected = " + expected)
val extracted = columnType.extract(buffer)
assert(
expected === extracted,