aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2014-04-02 18:14:31 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-04-02 18:14:31 -0700
commit47ebea5468df2e4f94ef493c5403fcdcda8c5eb2 (patch)
tree3d6957ea93d01a3b4164f0bcc64932c5b375649c /sql
parent9c65fa76f9d413e311a80f29d35d3ff7722e9476 (diff)
downloadspark-47ebea5468df2e4f94ef493c5403fcdcda8c5eb2.tar.gz
spark-47ebea5468df2e4f94ef493c5403fcdcda8c5eb2.tar.bz2
spark-47ebea5468df2e4f94ef493c5403fcdcda8c5eb2.zip
[SQL] SPARK-1364 Improve datatype and test coverage for ScalaReflection schema inference.
Author: Michael Armbrust <michael@databricks.com> Closes #293 from marmbrus/reflectTypes and squashes the following commits: f54e8e8 [Michael Armbrust] Improve datatype and test coverage for ScalaReflection schema inference.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala56
2 files changed, 66 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 976dda8d7e..5aaa63bf3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -43,15 +43,25 @@ object ScalaReflection {
val params = t.member("<init>": TermName).asMethod.paramss
StructType(
params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => BinaryType
+ case t if t <:< typeOf[Array[_]] =>
+ sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
ArrayType(schemaFor(elementType))
+ case t if t <:< typeOf[Map[_,_]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ MapType(schemaFor(keyType), schemaFor(valueType))
case t if t <:< typeOf[String] => StringType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
+ case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.DoubleTpe => DoubleType
case t if t <:< definitions.ShortTpe => ShortType
case t if t <:< definitions.ByteTpe => ByteType
+ case t if t <:< definitions.BooleanTpe => BooleanType
+ case t if t <:< typeOf[BigDecimal] => DecimalType
}
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
new file mode 100644
index 0000000000..70033a050c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class ReflectData(
+ stringField: String,
+ intField: Int,
+ longField: Long,
+ floatField: Float,
+ doubleField: Double,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean,
+ decimalField: BigDecimal,
+ seqInt: Seq[Int])
+
+case class ReflectBinary(data: Array[Byte])
+
+class ScalaReflectionRelationSuite extends FunSuite {
+ test("query case class RDD") {
+ val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), Seq(1,2,3))
+ val rdd = sparkContext.parallelize(data :: Nil)
+ rdd.registerAsTable("reflectData")
+
+ assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ }
+
+ // Equality is broken for Arrays, so we test that separately.
+ test("query binary data") {
+ val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
+ rdd.registerAsTable("reflectBinary")
+
+ val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
+ assert(result.toSeq === Seq[Byte](1))
+ }
+} \ No newline at end of file