From 76e3ba4264c4a0bc2c33ae6ac862fc40bc302d83 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 28 Aug 2014 00:15:23 -0700 Subject: [SPARK-3230][SQL] Fix udfs that return structs We need to convert the case classes into Rows. Author: Michael Armbrust Closes #2133 from marmbrus/structUdfs and squashes the following commits: 189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs 8e29b1c [Michael Armbrust] Use existing function d8d0b76 [Michael Armbrust] Fix udfs that return structs --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 12 ++++++++++-- .../org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala | 7 ++++++- .../org/apache/spark/sql/execution/basicOperators.scala | 11 ++--------- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 12 ++++++++++++ 4 files changed, 30 insertions(+), 12 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6b6b636cd9..88a8fa7c28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -32,6 +31,15 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) + /** Converts Scala objects to catalyst rows / types */ + def convertToCatalyst(a: Any): Any = a match { + case o: Option[_] => o.orNull + case s: Seq[_] => s.map(convertToCatalyst) + case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 589816ccec..1b687a443e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner @@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def nullable = true + override def toString = s"scalaUDF(${children.mkString(",")})" + /** This method has been generated by this script (1 to 22).map { x => @@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:off override def eval(input: Row): Any = { - children.size match { + val result = children.size match { case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => @@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi children(21).eval(input)) } // scalastyle:on + + ScalaReflection.convertToCatalyst(result) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 374af48b82..4abda21ffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -204,14 +204,6 @@ case class Sort( */ @DeveloperApi object ExistingRdd { - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { @@ -223,7 +215,7 @@ object ExistingRdd { bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = convertToCatalyst(r.productElement(i)) + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) i += 1 } @@ -245,6 +237,7 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } + /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 76aa9b0081..ef9b76b1e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +case class FunctionResult(f1: String, f2: String) + class UDFSuite extends QueryTest { test("Simple UDF") { @@ -33,4 +35,14 @@ class UDFSuite extends QueryTest { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) } + + + test("struct UDF") { + registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + + val result= + sql("SELECT returnStruct('test', 'test2') as ret") + .select("ret.f1".attr).first().getString(0) + assert(result == "test") + } } -- cgit v1.2.3