author     Michael Armbrust <michael@databricks.com>    2014-08-28 00:15:23 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-08-28 00:15:23 -0700
commit     76e3ba4264c4a0bc2c33ae6ac862fc40bc302d83
tree       b7b8a27d9eb5d530ce484db1040ea9e9cd742e30
parent     68f75dcdfe7e8ab229b73824692c4b3d4c39946c
[SPARK-3230][SQL] Fix udfs that return structs
We need to convert the case classes returned by UDFs into Rows.
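For context, here is a minimal, self-contained sketch of the conversion this patch centralizes. `toCatalyst` and `SimpleRow` are illustrative stand-ins for Spark's internal `ScalaReflection.convertToCatalyst` and `GenericRow` (the names here are hypothetical); the recursion mirrors the helper removed from `ExistingRdd` in the diff below:

    object ConversionSketch {
      // Stand-in for Spark's internal GenericRow, for illustration only.
      case class SimpleRow(values: Array[Any])

      // Mirrors the recursion in ScalaReflection.convertToCatalyst: turns Scala
      // values, including case classes (Products), into row-shaped data.
      def toCatalyst(a: Any): Any = a match {
        case o: Option[_] => o.orNull                  // Options become nullable values
        case s: Seq[_]    => s.map(toCatalyst)         // convert sequence elements recursively
        case m: Map[_, _] => m.map { case (k, v) => toCatalyst(k) -> toCatalyst(v) }
        case p: Product   =>                           // a case class becomes a row (struct)
          SimpleRow(p.productIterator.map(toCatalyst).toArray)
        case other        => other                     // primitives pass through unchanged
      }

      case class FunctionResult(f1: String, f2: String)

      def main(args: Array[String]): Unit = {
        val row = toCatalyst(FunctionResult("test", "test2")).asInstanceOf[SimpleRow]
        println(row.values.mkString(", "))             // prints: test, test2
      }
    }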
Author: Michael Armbrust <michael@databricks.com>
Closes #2133 from marmbrus/structUdfs and squashes the following commits:
189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
Diffstat (limited to 'sql/core')
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 11 ++---------
 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala                 | 12 ++++++++++++
 2 files changed, 14 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 374af48b82..4abda21ffe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -204,14 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[_] => s.map(convertToCatalyst)
-    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
@@ -223,7 +215,7 @@ object ExistingRdd {
       bufferedIterator.map { r =>
         var i = 0
         while (i < mutableRow.length) {
-          mutableRow(i) = convertToCatalyst(r.productElement(i))
+          mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
           i += 1
         }
 
@@ -245,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 76aa9b0081..ef9b76b1e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
 /* Implicits */
 import TestSQLContext._
 
+case class FunctionResult(f1: String, f2: String)
+
 class UDFSuite extends QueryTest {
 
   test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
     registerFunction("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
+
+
+  test("struct UDF") {
+    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+    val result =
+      sql("SELECT returnStruct('test', 'test2') as ret")
+        .select("ret.f1".attr).first().getString(0)
+    assert(result == "test")
+  }
 }
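For reference, this is what the fix enables at the API level, mirroring the new test. A sketch for a Spark 1.1-era shell or test body, assuming the implicits pulled in by `import TestSQLContext._`, which supply `registerFunction`, `sql`, and the String `.attr` conversion used below:

    import org.apache.spark.sql.test._
    import TestSQLContext._

    case class FunctionResult(f1: String, f2: String)

    // Register a UDF that returns a case class; with this patch the result is
    // converted to a Row, so "ret" below behaves as a struct column.
    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))

    val firstField = sql("SELECT returnStruct('test', 'test2') AS ret")
      .select("ret.f1".attr)      // a nested field of the struct resolves by name
      .first().getString(0)       // yields "test"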