author     Cheng Lian <lian@databricks.com>            2014-11-03 13:20:33 -0800
committer  Michael Armbrust <michael@databricks.com>   2014-11-03 13:20:33 -0800
commit     c238fb423d1011bd1b1e6201d769b72e52664fc6 (patch)
tree       a1d4de68b51efcd5f0d0c29c7732545f45edee96 /sql
parent     24544fbce05665ab4999a1fe5aac434d29cd912c (diff)
download   spark-c238fb423d1011bd1b1e6201d769b72e52664fc6.tar.gz
           spark-c238fb423d1011bd1b1e6201d769b72e52664fc6.tar.bz2
           spark-c238fb423d1011bd1b1e6201d769b72e52664fc6.zip
[SPARK-4202][SQL] Simple DSL support for Scala UDF
This feature is based on an offline discussion with mengxr and will hopefully be useful for the new MLlib pipeline API. Given the following test snippet

```scala
case class KeyValue(key: Int, value: String)
val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD
val foo = (a: Int, b: String) => a.toString + b
```

the newly introduced DSL enables the following syntax:

```scala
import org.apache.spark.sql.catalyst.dsl._
testData.select(Star(None), foo.call('key, 'value) as 'result)
```

which is equivalent to:

```scala
testData.registerTempTable("testData")
sqlContext.registerFunction("foo", foo)
sql("SELECT *, foo(key, value) AS result FROM testData")
```

Author: Cheng Lian <lian@databricks.com>

Closes #3067 from liancheng/udf-dsl and squashes the following commits:

f132818 [Cheng Lian] Adds DSL support for Scala UDF
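For intuition about why `foo.call('key, 'value)` compiles at all, here is a minimal, self-contained sketch of the builder pattern the DSL relies on. `FakeUdf`, `UdfBuilder`, `fn2ToBuilder`, and `UdfDslSketch` are simplified names invented for this illustration (the real patch uses `ScalaUdf`, `ScalaUdfBuilder`, and 22 generated implicits); it runs without Spark:

```scala
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

object UdfDslSketch extends App {
  // Stand-in for Catalyst's ScalaUdf expression: records the function, its
  // return type, and the argument expressions (plain Symbols here).
  case class FakeUdf(f: AnyRef, returnType: String, args: Seq[Symbol])

  // Mirrors ScalaUdfBuilder: the TypeTag context bound captures the function's
  // return type, which the real DSL maps to a Catalyst DataType via ScalaReflection.
  case class UdfBuilder[T: TypeTag](f: AnyRef) {
    def call(args: Symbol*) = FakeUdf(f, typeTag[T].tpe.toString, args)
  }

  // Mirrors one of the generated implicits: lifting a plain function into a
  // builder is exactly what makes `foo.call('key, 'value)` type-check.
  implicit def fn2ToBuilder[T: TypeTag](func: Function2[_, _, T]): UdfBuilder[T] =
    UdfBuilder[T](func)

  val foo = (a: Int, b: String) => a.toString + b
  println(foo.call('key, 'value)) // e.g. FakeUdf(<function2>, String, List('key, 'value))
}
```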
Diffstat (limited to 'sql')
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 59
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala            | 17
2 files changed, 72 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7e6d770314..3314e15477 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
@@ -285,4 +286,62 @@ package object dsl {
def writeToFile(path: String) = WriteToFile(path, logicalPlan)
}
}
+
+ case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
+ def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ }
+
+ // scalastyle:off
+ /** functionToUdfBuilder 1-22 were generated by this script
+
+ (1 to 22).map { x =>
+ val argTypes = Seq.fill(x)("_").mkString(", ")
+ s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
+ }
+ */
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+ // scalastyle:on
}
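The comment block in the hunk above quotes the script that produced the 22 arity-specific overloads. A standalone version that prints the definitions ready for pasting might look like this (a sketch; `GenUdfBuilders` is a name invented here):

```scala
object GenUdfBuilders extends App {
  // Emit one implicit conversion per FunctionN arity (1 to 22), matching the
  // generator quoted in the package.scala comment above.
  (1 to 22).foreach { x =>
    val argTypes = Seq.fill(x)("_").mkString(", ")
    println(s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)")
    println()
  }
}
```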
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 45e58afe9d..e70ad891ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.test._
/* Implicits */
-import TestSQLContext._
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.test.TestSQLContext._
class DslQuerySuite extends QueryTest {
- import TestData._
+ import org.apache.spark.sql.TestData._
test("table scan") {
checkAnswer(
@@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest {
(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
+
+ test("udf") {
+ val foo = (a: Int, b: String) => a.toString + b
+
+ checkAnswer(
+ // SELECT *, foo(key, value) FROM testData
+ testData.select(Star(None), foo.call('key, 'value)).limit(3),
+ (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+ )
+ }
}
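Putting the two hunks together, an end-to-end spark-shell session might look as follows. This is a sketch against the 1.2-era API this patch targets; `sc` and `sqlContext` are assumed to be the shell-provided instances, and exact imports may vary:

```scala
import sqlContext._                                  // implicit createSchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.dsl._

case class KeyValue(key: Int, value: String)
val testData = sc.parallelize(1 to 10).map(i => KeyValue(i, i.toString)).toSchemaRDD

val foo = (a: Int, b: String) => a.toString + b

// The new implicit lifts `foo` into a ScalaUdfBuilder; `call` wraps it in a
// ScalaUdf expression over the unresolved `key` and `value` attributes.
testData.select(Star(None), foo.call('key, 'value) as 'result).collect().foreach(println)
// expected rows: [1,1,11], [2,2,22], ..., [10,10,1010]
```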