From 47f8829e0504883f2822195077afee6857ce0b16 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Wed, 25 Jun 2014 23:55:31 -0700 Subject: [SPARK-2254] [SQL] ScalaRefection should mark primitive types as non-nullable. Author: Takuya UESHIN Closes #1193 from ueshin/issues/SPARK-2254 and squashes the following commits: cfd6088 [Takuya UESHIN] Modify ScalaRefection.schemaFor method to return nullability of Scala Type. (cherry picked from commit e4899a253728bfa7c78709a37a4837f74b72bd61) Signed-off-by: Reynold Xin --- .../spark/sql/catalyst/ScalaReflection.scala | 65 +++++----- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 131 +++++++++++++++++++++ 2 files changed, 165 insertions(+), 31 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 196695a0a1..ada48eaf5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -30,53 +30,56 @@ import org.apache.spark.sql.catalyst.types._ object ScalaReflection { import scala.reflect.runtime.universe._ + case class Schema(dataType: DataType, nullable: Boolean) + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { - case s: StructType => - s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + case Schema(s: StructType, _) => + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) } - /** Returns a catalyst DataType for the given Scala Type using reflection. */ - def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. 
*/ + def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) - /** Returns a catalyst DataType for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): DataType = tpe match { + /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ + def schemaFor(tpe: `Type`): Schema = tpe match { case t if t <:< typeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - schemaFor(optType) + Schema(schemaFor(optType).dataType, nullable = true) case t if t <:< typeOf[Product] => val params = t.member("<init>": TermName).asMethod.paramss - StructType( - params.head.map(p => - StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true))) + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = schemaFor(p.typeSignature) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) // Need to decide if we actually need a special type here. - case t if t <:< typeOf[Array[Byte]] => BinaryType + case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - ArrayType(schemaFor(elementType)) + Schema(ArrayType(schemaFor(elementType).dataType), nullable = true) case t if t <:< typeOf[Map[_,_]] => val TypeRef(_, _, Seq(keyType, valueType)) = t - MapType(schemaFor(keyType), schemaFor(valueType)) - case t if t <:< typeOf[String] => StringType - case t if t <:< typeOf[Timestamp] => TimestampType - case t if t <:< typeOf[BigDecimal] => DecimalType - case t if t <:< typeOf[java.lang.Integer] => IntegerType - case t if t <:< typeOf[java.lang.Long] => LongType - case t if t <:< typeOf[java.lang.Double] => DoubleType - case t if t <:< typeOf[java.lang.Float] => FloatType - case t if t <:< typeOf[java.lang.Short] => ShortType - case t if t <:< typeOf[java.lang.Byte] => ByteType - case t if t <:<
typeOf[java.lang.Boolean] => BooleanType - // TODO: The following datatypes could be marked as non-nullable. - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType + Schema(MapType(schemaFor(keyType).dataType, schemaFor(valueType).dataType), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala new file mode 100644 index 0000000000..489d7e9c24 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst + +import java.sql.Timestamp + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +case class PrimitiveData( + intField: Int, + longField: Long, + doubleField: Double, + floatField: Float, + shortField: Short, + byteField: Byte, + booleanField: Boolean) + +case class NullableData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean, + stringField: String, + decimalField: BigDecimal, + timestampField: Timestamp, + binaryField: Array[Byte]) + +case class OptionalData( + intField: Option[Int], + longField: Option[Long], + doubleField: Option[Double], + floatField: Option[Float], + shortField: Option[Short], + byteField: Option[Byte], + booleanField: Option[Boolean]) + +case class ComplexData( + arrayField: Seq[Int], + mapField: Map[Int, String], + structField: PrimitiveData) + +class ScalaReflectionSuite extends FunSuite { + import ScalaReflection._ + + test("primitive data") { + val schema = schemaFor[PrimitiveData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), + StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true)) + } + + test("nullable data") { + val schema = schemaFor[NullableData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + 
StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", BooleanType, nullable = true), + StructField("stringField", StringType, nullable = true), + StructField("decimalField", DecimalType, nullable = true), + StructField("timestampField", TimestampType, nullable = true), + StructField("binaryField", BinaryType, nullable = true))), + nullable = true)) + } + + test("optinal data") { + val schema = schemaFor[OptionalData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = true), + StructField("longField", LongType, nullable = true), + StructField("doubleField", DoubleType, nullable = true), + StructField("floatField", FloatType, nullable = true), + StructField("shortField", ShortType, nullable = true), + StructField("byteField", ByteType, nullable = true), + StructField("booleanField", BooleanType, nullable = true))), + nullable = true)) + } + + test("complex data") { + val schema = schemaFor[ComplexData] + assert(schema === Schema( + StructType(Seq( + StructField("arrayField", ArrayType(IntegerType), nullable = true), + StructField("mapField", MapType(IntegerType, StringType), nullable = true), + StructField( + "structField", + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("longField", LongType, nullable = false), + StructField("doubleField", DoubleType, nullable = false), + StructField("floatField", FloatType, nullable = false), + StructField("shortField", ShortType, nullable = false), + StructField("byteField", ByteType, nullable = false), + StructField("booleanField", BooleanType, nullable = false))), + nullable = true))), + nullable = true)) + } +} -- cgit v1.2.3