aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@appier.com>2015-11-06 10:52:04 -0800
committerDavies Liu <davies.liu@gmail.com>2015-11-06 10:52:04 -0800
commit574141a29835ce78d68c97bb54336cf4fd3c39d3 (patch)
tree58524a8974e6bb83da3584761a7ec90ff0b913df
parentcf69ce136590fea51843bc54f44f0f45c7d0ac36 (diff)
downloadspark-574141a29835ce78d68c97bb54336cf4fd3c39d3.tar.gz
spark-574141a29835ce78d68c97bb54336cf4fd3c39d3.tar.bz2
spark-574141a29835ce78d68c97bb54336cf4fd3c39d3.zip
[SPARK-9162] [SQL] Implement code generation for ScalaUDF
JIRA: https://issues.apache.org/jira/browse/SPARK-9162 Currently ScalaUDF extends CodegenFallback and doesn't provide code generation implementation. This patch implements code generation for ScalaUDF. Author: Liang-Chi Hsieh <viirya@appier.com> Closes #9270 from viirya/scalaudf-codegen.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala85
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala41
2 files changed, 124 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 11c7950c06..3388cc20a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.DataType
/**
@@ -31,7 +31,7 @@ case class ScalaUDF(
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil)
- extends Expression with ImplicitCastInputTypes with CodegenFallback {
+ extends Expression with ImplicitCastInputTypes {
override def nullable: Boolean = true
@@ -60,6 +60,10 @@ case class ScalaUDF(
*/
+ // Accessors used in genCode
+ def userDefinedFunc(): AnyRef = function
+ def getChildren(): Seq[Expression] = children
+
private[this] val f = children.size match {
case 0 =>
val func = function.asInstanceOf[() => Any]
@@ -960,6 +964,83 @@ case class ScalaUDF(
}
// scalastyle:on
+
+  // Generate codes used to convert the arguments to Scala type for user-defined functions
+ private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = {
+ val converterClassName = classOf[Any => Any].getName
+ val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
+ val expressionClassName = classOf[Expression].getName
+ val scalaUDFClassName = classOf[ScalaUDF].getName
+
+ val converterTerm = ctx.freshName("converter")
+ val expressionIdx = ctx.references.size - 1
+ ctx.addMutableState(converterClassName, converterTerm,
+ s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" +
+ s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
+ s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());")
+ converterTerm
+ }
+
+ override def genCode(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode): String = {
+
+ ctx.references += this
+
+ val scalaUDFClassName = classOf[ScalaUDF].getName
+ val converterClassName = classOf[Any => Any].getName
+ val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
+ val expressionClassName = classOf[Expression].getName
+
+ // Generate codes used to convert the returned value of user-defined functions to Catalyst type
+ val catalystConverterTerm = ctx.freshName("catalystConverter")
+ val catalystConverterTermIdx = ctx.references.size - 1
+ ctx.addMutableState(converterClassName, catalystConverterTerm,
+ s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
+ s".createToCatalystConverter((($scalaUDFClassName)expressions" +
+ s"[$catalystConverterTermIdx]).dataType());")
+
+ val resultTerm = ctx.freshName("result")
+
+ // This must be called before children expressions' codegen
+ // because ctx.references is used in genCodeForConverter
+ val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _))
+
+ // Initialize user-defined function
+ val funcClassName = s"scala.Function${children.size}"
+
+ val funcTerm = ctx.freshName("udf")
+ val funcExpressionIdx = ctx.references.size - 1
+ ctx.addMutableState(funcClassName, funcTerm,
+ s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" +
+ s"[$funcExpressionIdx]).userDefinedFunc());")
+
+ // codegen for children expressions
+ val evals = children.map(_.gen(ctx))
+
+ // Generate the codes for expressions and calling user-defined function
+ // We need to get the boxedType of dataType's javaType here. Because for the dataType
+ // such as IntegerType, its javaType is `int` and the returned type of user-defined
+ // function is Object. Trying to convert an Object to `int` will cause casting exception.
+ val evalCode = evals.map(_.code).mkString
+ val funcArguments = converterTerms.zip(evals).map {
+ case (converter, eval) => s"$converter.apply(${eval.value})"
+ }.mkString(",")
+ val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
+ s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
+ s".apply($funcTerm.apply($funcArguments));"
+
+ evalCode + s"""
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ Boolean ${ev.isNull};
+
+ $callFunc
+
+ ${ev.value} = $resultTerm;
+ ${ev.isNull} = $resultTerm == null;
+ """
+ }
+
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
override def eval(input: InternalRow): Any = converter(f(input))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e0435a0dba..9837fa6bdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -191,4 +191,45 @@ class UDFSuite extends QueryTest with SharedSQLContext {
// pass a decimal to intExpected.
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
+
+ test("udf in different types") {
+ sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) })
+ sqlContext.udf.register("decimalDataFunc",
+ (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) })
+ sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) })
+ sqlContext.udf.register("arrayDataFunc",
+ (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) })
+ sqlContext.udf.register("mapDataFunc",
+ (data: scala.collection.Map[Int, String]) => { data })
+ sqlContext.udf.register("complexDataFunc",
+ (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } )
+
+ checkAnswer(
+ sql("SELECT tmp.t.* FROM (SELECT testDataFunc(key, value) AS t from testData) tmp").toDF(),
+ testData)
+ checkAnswer(
+ sql("""
+ | SELECT tmp.t.* FROM
+ | (SELECT decimalDataFunc(a, b) AS t FROM decimalData) tmp
+ """.stripMargin).toDF(), decimalData)
+ checkAnswer(
+ sql("""
+ | SELECT tmp.t.* FROM
+ | (SELECT binaryDataFunc(a, b) AS t FROM binaryData) tmp
+ """.stripMargin).toDF(), binaryData)
+ checkAnswer(
+ sql("""
+ | SELECT tmp.t.* FROM
+ | (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp
+ """.stripMargin).toDF(), arrayData.toDF())
+ checkAnswer(
+ sql("""
+ | SELECT mapDataFunc(data) AS t FROM mapData
+ """.stripMargin).toDF(), mapData.toDF())
+ checkAnswer(
+ sql("""
+ | SELECT tmp.t.* FROM
+ | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp
+ """.stripMargin).toDF(), complexData.select("m", "a", "b"))
+ }
}