diff options
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 11 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 12 |
2 files changed, 14 insertions, 9 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 374af48b82..4abda21ffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -204,14 +204,6 @@ case class Sort( */ @DeveloperApi object ExistingRdd { - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { @@ -223,7 +215,7 @@ object ExistingRdd { bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = convertToCatalyst(r.productElement(i)) + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) i += 1 } @@ -245,6 +237,7 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } + /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 76aa9b0081..ef9b76b1e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +case class FunctionResult(f1: String, f2: String) + class UDFSuite extends QueryTest { test("Simple UDF") { @@ -33,4 +35,14 @@ class UDFSuite extends QueryTest { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) } + + + test("struct UDF") { + registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + + val result= + sql("SELECT returnStruct('test', 'test2') as ret") + .select("ret.f1".attr).first().getString(0) + assert(result == "test") + } } |