diff options
author | Michael Armbrust <michael@databricks.com> | 2014-11-03 18:04:51 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-11-03 18:04:51 -0800 |
commit | 15b58a2234ab7ba30c9c0cbb536177a3c725e350 (patch) | |
tree | 86761bc9b2e4fb0dfb340510e92bc16af6fa5152 /sql/core/src | |
parent | 28128150e7e0c2b7d1c483e67214bdaef59f7d75 (diff) | |
download | spark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.tar.gz spark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.tar.bz2 spark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.zip |
[SQL] Convert arguments to Scala UDFs
Author: Michael Armbrust <michael@databricks.com>
Closes #3077 from marmbrus/udfsWithUdts and squashes the following commits:
34b5f27 [Michael Armbrust] style
504adef [Michael Armbrust] Convert arguments to Scala UDFs
Diffstat (limited to 'sql/core/src')
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 666235e57f..1806a1dd82 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -60,13 +60,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) - test("register user type: MyDenseVector for MyLabeledPoint") { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) @@ -80,4 +80,12 @@ class UserDefinedTypeSuite extends QueryTest { assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) } + + test("UDTs and UDFs") { + registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + pointsRDD.registerTempTable("points") + checkAnswer( + sql("SELECT testType(features) from points"), + Seq(Row(true), Row(true))) + } } |