author    Michael Armbrust <michael@databricks.com>    2014-08-28 00:15:23 -0700
committer Michael Armbrust <michael@databricks.com>    2014-08-28 00:15:23 -0700
commit    76e3ba4264c4a0bc2c33ae6ac862fc40bc302d83
tree      b7b8a27d9eb5d530ce484db1040ea9e9cd742e30
parent    68f75dcdfe7e8ab229b73824692c4b3d4c39946c
[SPARK-3230][SQL] Fix udfs that return structs
We need to convert the case classes into Rows.

Author: Michael Armbrust <michael@databricks.com>

Closes #2133 from marmbrus/structUdfs and squashes the following commits:

189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
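For context, a minimal sketch of what this patch enables (the case class and UDF names here are illustrative, not part of the patch; registerFunction and .attr are the same APIs used in the test below). Before this change, a Scala UDF whose result type was a case class handed the raw object to the execution engine instead of a catalyst Row, so selecting a field of the struct result failed:

    // Hypothetical example of a struct-returning UDF.
    case class Point(x: Int, y: Int)
    registerFunction("makePoint", (x: Int, y: Int) => Point(x, y))

    // With this patch the case class result is converted to a Row,
    // so fields of the returned struct can be selected:
    sql("SELECT makePoint(1, 2) AS p").select("p.x".attr).first().getInt(0)  // 1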
Diffstat (limited to 'sql')
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala      | 12
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala |  7
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala          | 11
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala                          | 12
4 files changed, 30 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6b6b636cd9..88a8fa7c28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -32,6 +31,15 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
+ /** Converts Scala objects to catalyst rows / types */
+ def convertToCatalyst(a: Any): Any = a match {
+ case o: Option[_] => o.orNull
+ case s: Seq[_] => s.map(convertToCatalyst)
+ case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
+ case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+ case other => other
+ }
+
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
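To make the new helper concrete, a sketch of how convertToCatalyst treats each case (the return values in comments are illustrative; Point is a hypothetical case class, not part of the patch):

    ScalaReflection.convertToCatalyst(Some(1))             // 1 (Option unwrapped; None becomes null)
    ScalaReflection.convertToCatalyst(Seq(Some(1), None))  // Seq(1, null) (converted element-wise)
    case class Point(x: Int, y: Option[Int])
    ScalaReflection.convertToCatalyst(Point(1, None))      // GenericRow with fields [1, null]

Note the recursion: nested case classes, Seqs, and Maps are all normalized, so arbitrarily nested struct results come out as nested Rows.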
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 589816ccec..1b687a443e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner
@@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
def nullable = true
+ override def toString = s"scalaUDF(${children.mkString(",")})"
+
/** This method has been generated by this script
(1 to 22).map { x =>
@@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
// scalastyle:off
override def eval(input: Row): Any = {
- children.size match {
+ val result = children.size match {
case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
@@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
children(21).eval(input))
}
// scalastyle:on
+
+ ScalaReflection.convertToCatalyst(result)
}
}
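Collapsed to the one-argument case, the reshaped eval now reads as follows (a hand-simplified view of the generated code above, not a verbatim excerpt):

    override def eval(input: Row): Any = {
      // Apply the user's function to the evaluated child expression...
      val result = function.asInstanceOf[(Any) => Any](children(0).eval(input))
      // ...then normalize Options, Seqs, Maps, and case classes into catalyst form.
      ScalaReflection.convertToCatalyst(result)
    }

Binding the match to `result` and converting at the single exit point keeps the fix out of all 23 generated branches (case 0 through case 22).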
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 374af48b82..4abda21ffe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -204,14 +204,6 @@ case class Sort(
*/
@DeveloperApi
object ExistingRdd {
- def convertToCatalyst(a: Any): Any = a match {
- case o: Option[_] => o.orNull
- case s: Seq[_] => s.map(convertToCatalyst)
- case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case other => other
- }
-
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
data.mapPartitions { iterator =>
if (iterator.isEmpty) {
@@ -223,7 +215,7 @@ object ExistingRdd {
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
- mutableRow(i) = convertToCatalyst(r.productElement(i))
+ mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
i += 1
}
@@ -245,6 +237,7 @@ object ExistingRdd {
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
+
/**
* :: DeveloperApi ::
* Computes the set of distinct input rows using a HashSet.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 76aa9b0081..ef9b76b1e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
+case class FunctionResult(f1: String, f2: String)
+
class UDFSuite extends QueryTest {
test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
registerFunction("strLenScala", (_: String).length + (_:Int))
assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
}
+
+
+ test("struct UDF") {
+ registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+ val result=
+ sql("SELECT returnStruct('test', 'test2') as ret")
+ .select("ret.f1".attr).first().getString(0)
+ assert(result == "test")
+ }
}
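The new test registers a UDF that returns a case class and then selects a field of the struct result, which is exactly the path that failed before this patch. To run just this suite (assuming the sbt build of this era; the exact invocation may differ):

    sbt/sbt "sql/test-only org.apache.spark.sql.UDFSuite"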