aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorTakuya UESHIN <ueshin@happy-camper.st>2014-06-25 23:55:31 -0700
committerReynold Xin <rxin@apache.org>2014-06-25 23:55:38 -0700
commit47f8829e0504883f2822195077afee6857ce0b16 (patch)
treea389d55e303375b0677e8772cb373f4a6f3d6087 /sql
parentc445b3af3b6c524c6113dd036df0ed2f909d184c (diff)
downloadspark-47f8829e0504883f2822195077afee6857ce0b16.tar.gz
spark-47f8829e0504883f2822195077afee6857ce0b16.tar.bz2
spark-47f8829e0504883f2822195077afee6857ce0b16.zip
[SPARK-2254] [SQL] ScalaRefection should mark primitive types as non-nullable.
Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #1193 from ueshin/issues/SPARK-2254 and squashes the following commits: cfd6088 [Takuya UESHIN] Modify ScalaRefection.schemaFor method to return nullability of Scala Type. (cherry picked from commit e4899a253728bfa7c78709a37a4837f74b72bd61) Signed-off-by: Reynold Xin <rxin@apache.org>
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala65
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala131
2 files changed, 165 insertions, 31 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 196695a0a1..ada48eaf5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._
object ScalaReflection {
import scala.reflect.runtime.universe._
+ case class Schema(dataType: DataType, nullable: Boolean)
+
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
- case s: StructType =>
- s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)())
+ case Schema(s: StructType, _) =>
+ s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)())
}
- /** Returns a catalyst DataType for the given Scala Type using reflection. */
- def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T])
+ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
- /** Returns a catalyst DataType for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`): DataType = tpe match {
+ /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ def schemaFor(tpe: `Type`): Schema = tpe match {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- schemaFor(optType)
+ Schema(schemaFor(optType).dataType, nullable = true)
case t if t <:< typeOf[Product] =>
val params = t.member("<init>": TermName).asMethod.paramss
- StructType(
- params.head.map(p =>
- StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
+ Schema(StructType(
+ params.head.map { p =>
+ val Schema(dataType, nullable) = schemaFor(p.typeSignature)
+ StructField(p.name.toString, dataType, nullable)
+ }), nullable = true)
// Need to decide if we actually need a special type here.
- case t if t <:< typeOf[Array[Byte]] => BinaryType
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[_]] =>
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- ArrayType(schemaFor(elementType))
+ Schema(ArrayType(schemaFor(elementType).dataType), nullable = true)
case t if t <:< typeOf[Map[_,_]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
- MapType(schemaFor(keyType), schemaFor(valueType))
- case t if t <:< typeOf[String] => StringType
- case t if t <:< typeOf[Timestamp] => TimestampType
- case t if t <:< typeOf[BigDecimal] => DecimalType
- case t if t <:< typeOf[java.lang.Integer] => IntegerType
- case t if t <:< typeOf[java.lang.Long] => LongType
- case t if t <:< typeOf[java.lang.Double] => DoubleType
- case t if t <:< typeOf[java.lang.Float] => FloatType
- case t if t <:< typeOf[java.lang.Short] => ShortType
- case t if t <:< typeOf[java.lang.Byte] => ByteType
- case t if t <:< typeOf[java.lang.Boolean] => BooleanType
- // TODO: The following datatypes could be marked as non-nullable.
- case t if t <:< definitions.IntTpe => IntegerType
- case t if t <:< definitions.LongTpe => LongType
- case t if t <:< definitions.DoubleTpe => DoubleType
- case t if t <:< definitions.FloatTpe => FloatType
- case t if t <:< definitions.ShortTpe => ShortType
- case t if t <:< definitions.ByteTpe => ByteType
- case t if t <:< definitions.BooleanTpe => BooleanType
+ Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true)
+ case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+ case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+ case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+ case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+ case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+ case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+ case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+ case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+ case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+ case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+ case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+ case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+ case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+ case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+ case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+ case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+ case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
}
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
new file mode 100644
index 0000000000..489d7e9c24
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import java.sql.Timestamp
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types._
+
+case class PrimitiveData(
+ intField: Int,
+ longField: Long,
+ doubleField: Double,
+ floatField: Float,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean)
+
+case class NullableData(
+ intField: java.lang.Integer,
+ longField: java.lang.Long,
+ doubleField: java.lang.Double,
+ floatField: java.lang.Float,
+ shortField: java.lang.Short,
+ byteField: java.lang.Byte,
+ booleanField: java.lang.Boolean,
+ stringField: String,
+ decimalField: BigDecimal,
+ timestampField: Timestamp,
+ binaryField: Array[Byte])
+
+case class OptionalData(
+ intField: Option[Int],
+ longField: Option[Long],
+ doubleField: Option[Double],
+ floatField: Option[Float],
+ shortField: Option[Short],
+ byteField: Option[Byte],
+ booleanField: Option[Boolean])
+
+case class ComplexData(
+ arrayField: Seq[Int],
+ mapField: Map[Int, String],
+ structField: PrimitiveData)
+
+class ScalaReflectionSuite extends FunSuite {
+ import ScalaReflection._
+
+ test("primitive data") {
+ val schema = schemaFor[PrimitiveData]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("intField", IntegerType, nullable = false),
+ StructField("longField", LongType, nullable = false),
+ StructField("doubleField", DoubleType, nullable = false),
+ StructField("floatField", FloatType, nullable = false),
+ StructField("shortField", ShortType, nullable = false),
+ StructField("byteField", ByteType, nullable = false),
+ StructField("booleanField", BooleanType, nullable = false))),
+ nullable = true))
+ }
+
+ test("nullable data") {
+ val schema = schemaFor[NullableData]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("intField", IntegerType, nullable = true),
+ StructField("longField", LongType, nullable = true),
+ StructField("doubleField", DoubleType, nullable = true),
+ StructField("floatField", FloatType, nullable = true),
+ StructField("shortField", ShortType, nullable = true),
+ StructField("byteField", ByteType, nullable = true),
+ StructField("booleanField", BooleanType, nullable = true),
+ StructField("stringField", StringType, nullable = true),
+ StructField("decimalField", DecimalType, nullable = true),
+ StructField("timestampField", TimestampType, nullable = true),
+ StructField("binaryField", BinaryType, nullable = true))),
+ nullable = true))
+ }
+
+ test("optinal data") {
+ val schema = schemaFor[OptionalData]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("intField", IntegerType, nullable = true),
+ StructField("longField", LongType, nullable = true),
+ StructField("doubleField", DoubleType, nullable = true),
+ StructField("floatField", FloatType, nullable = true),
+ StructField("shortField", ShortType, nullable = true),
+ StructField("byteField", ByteType, nullable = true),
+ StructField("booleanField", BooleanType, nullable = true))),
+ nullable = true))
+ }
+
+ test("complex data") {
+ val schema = schemaFor[ComplexData]
+ assert(schema === Schema(
+ StructType(Seq(
+ StructField("arrayField", ArrayType(IntegerType), nullable = true),
+ StructField("mapField", MapType(IntegerType, StringType), nullable = true),
+ StructField(
+ "structField",
+ StructType(Seq(
+ StructField("intField", IntegerType, nullable = false),
+ StructField("longField", LongType, nullable = false),
+ StructField("doubleField", DoubleType, nullable = false),
+ StructField("floatField", FloatType, nullable = false),
+ StructField("shortField", ShortType, nullable = false),
+ StructField("byteField", ByteType, nullable = false),
+ StructField("booleanField", BooleanType, nullable = false))),
+ nullable = true))),
+ nullable = true))
+ }
+}