aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2014-11-03 18:04:51 -0800
committerMichael Armbrust <michael@databricks.com>2014-11-03 18:04:51 -0800
commit15b58a2234ab7ba30c9c0cbb536177a3c725e350 (patch)
tree86761bc9b2e4fb0dfb340510e92bc16af6fa5152 /sql/core/src
parent28128150e7e0c2b7d1c483e67214bdaef59f7d75 (diff)
downloadspark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.tar.gz
spark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.tar.bz2
spark-15b58a2234ab7ba30c9c0cbb536177a3c725e350.zip
[SQL] Convert arguments to Scala UDFs
Author: Michael Armbrust <michael@databricks.com> Closes #3077 from marmbrus/udfsWithUdts and squashes the following commits: 34b5f27 [Michael Armbrust] style 504adef [Michael Armbrust] Convert arguments to Scala UDFs
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala18
1 files changed, 13 insertions, 5 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 666235e57f..1806a1dd82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -60,13 +60,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
}
class UserDefinedTypeSuite extends QueryTest {
+ val points = Seq(
+ MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
+ val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
- test("register user type: MyDenseVector for MyLabeledPoint") {
- val points = Seq(
- MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
- MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+ test("register user type: MyDenseVector for MyLabeledPoint") {
val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
val labelsArrays: Array[Double] = labels.collect()
assert(labelsArrays.size === 2)
@@ -80,4 +80,12 @@ class UserDefinedTypeSuite extends QueryTest {
assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
}
+
+ test("UDTs and UDFs") {
+ registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+ pointsRDD.registerTempTable("points")
+ checkAnswer(
+ sql("SELECT testType(features) from points"),
+ Seq(Row(true), Row(true)))
+ }
}