authorMichael Armbrust <michael@databricks.com>2016-02-02 10:15:40 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-02 10:15:40 -0800
commit29d92181d0c49988c387d34e4a71b1afe02c29e2 (patch)
treeccda4acc62f202a594462d1164f85bb31dc6a04b /sql
parent12a20c144f14e80ef120ddcfb0b455a805a2da23 (diff)
[SPARK-13094][SQL] Add encoders for seq/array of primitives
Author: Michael Armbrust <michael@databricks.com>

Closes #11014 from marmbrus/seqEncoders.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala          | 63
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala             |  8
3 files changed, 91 insertions, 2 deletions
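
Before this patch, calling .toDS() on a Seq[Seq[Int]] or a Seq[Array[String]]
failed to compile because no implicit Encoder for the element type was in
scope; the new implicits make these resolve automatically. A minimal usage
sketch (assuming a Spark 1.6-style `sqlContext` in scope, as in the Spark
shell; the `toDS()` conversion itself comes from the pre-existing
`localSeqToDatasetHolder` implicit in this same file):

    import sqlContext.implicits._

    // Encoder[Seq[Int]] is now supplied by newIntSeqEncoder
    val ds = Seq(Seq(1, 2, 3), Seq(4, 5)).toDS()
    ds.map(_.sum).collect()                  // Array(6, 9)

    // Encoder[Array[String]] is supplied by newStringArrayEncoder
    val names = Seq(Array("a", "b")).toDS()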
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index ab414799f1..16c4095db7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -39,6 +39,8 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
+ // Primitives
+
/** @since 1.6.0 */
implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
@@ -56,13 +58,72 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
- /** @since 1.6.0 */
+ /** @since 1.6.0 */
implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
/** @since 1.6.0 */
implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
+ // Seqs
+
+ /** @since 1.6.1 */
+ implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+
+ // Arrays
+
+ /** @since 1.6.1 */
+ implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()
+
+ /** @since 1.6.1 */
+ implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
+ ExpressionEncoder()
+
/**
* Creates a [[Dataset]] from an RDD.
* @since 1.6.0
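
Note that every new implicit simply delegates to ExpressionEncoder(), which
derives the serializer/deserializer expressions from the TypeTag of the
requested type at the call site; the Product variants keep the TypeTag
context bound so case class fields can be reflected on. A sketch of
materializing the same encoders explicitly (ExpressionEncoder is already
imported by this file; Point is a hypothetical case class):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Equivalent to what newIntSeqEncoder returns
    val seqEnc = ExpressionEncoder[Seq[Int]]()

    // What newProductArrayEncoder derives for a case class element type
    case class Point(x: Int, y: Int)
    val arrEnc = ExpressionEncoder[Array[Point]]()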
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f75d096182..243d13b19d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
+
+ test("Arrays and Lists") {
+ checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
+ checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+ checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+ checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+ checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+ checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+ checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
+ checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
+ checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+
+ checkAnswer(Seq(Array(1)).toDS(), Array(1))
+ checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+ checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+ checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+ checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+ checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+ checkAnswer(Seq(Array(true)).toDS(), Array(true))
+ checkAnswer(Seq(Array("test")).toDS(), Array("test"))
+ checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+ }
}
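
Each checkAnswer line above asserts a full encode/decode round trip. Written
out as a plain session (hypothetical, assuming sqlContext.implicits._ is
imported):

    val ds = Seq(Seq(1.toShort, 2.toShort)).toDS()
    ds.collect().head == Seq(1.toShort, 2.toShort)   // true: contents survive

The Tuple1 cases exercise the new Product variants (newProductSeqEncoder and
newProductArrayEncoder) rather than the per-primitive encoders.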
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 405e5891ac..5401212428 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}
- if (decoded != expectedAnswer.toSet) {
+ // Handle the case where the return type is an array
+ val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
+ def normalEquality = decoded == expectedAnswer.toSet
+ def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
+ def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
+
+ if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
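
The isArray special case exists because JVM arrays compare by reference, so a
correctly decoded Array would never be == to the expected one even when the
contents match; converting both sides to Seq restores structural equality.
Illustrated in plain Scala (no Spark needed):

    Array(1, 2) == Array(1, 2)                // false: reference comparison
    Array(1, 2).toSeq == Array(1, 2).toSeq    // true: Seq equality is structural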