aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rwxr-xr-xsql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala59
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala17
2 files changed, 72 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7e6d770314..3314e15477 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.language.implicitConversions
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
@@ -285,4 +286,62 @@ package object dsl {
def writeToFile(path: String) = WriteToFile(path, logicalPlan)
}
}
+
+ case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) {
+ def call(args: Expression*) = ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args)
+ }
+
+ // scalastyle:off
+ /** functionToUdfBuilder 1-22 were generated by this script
+
+ (1 to 22).map { x =>
+ val argTypes = Seq.fill(x)("_").mkString(", ")
+ s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]) = ScalaUdfBuilder(func)"
+ }
+ */
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+
+ implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]) = ScalaUdfBuilder(func)
+ // scalastyle:on
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 45e58afe9d..e70ad891ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.test._
/* Implicits */
-import TestSQLContext._
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.test.TestSQLContext._
class DslQuerySuite extends QueryTest {
- import TestData._
+ import org.apache.spark.sql.TestData._
test("table scan") {
checkAnswer(
@@ -216,4 +215,14 @@ class DslQuerySuite extends QueryTest {
(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
+
+ test("udf") {
+ val foo = (a: Int, b: String) => a.toString + b
+
+ checkAnswer(
+ // SELECT *, foo(key, value) FROM testData
+ testData.select(Star(None), foo.call('key, 'value)).limit(3),
+ (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
+ )
+ }
}