aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVida Ha <vida@databricks.com>2014-10-09 13:13:31 -0700
committerMichael Armbrust <michael@databricks.com>2014-10-09 13:13:31 -0700
commitb77a02f41c60d869f48b65e72ed696c05b30bc48 (patch)
tree0b65df55fbb77df68aafbb452137255cc61ac30d
parent73bf3f2e0c03216aa29c25fea2d97205b5977903 (diff)
downloadspark-b77a02f41c60d869f48b65e72ed696c05b30bc48.tar.gz
spark-b77a02f41c60d869f48b65e72ed696c05b30bc48.tar.bz2
spark-b77a02f41c60d869f48b65e72ed696c05b30bc48.zip
[SPARK-3752][SQL]: Add tests for different UDF's
Author: Vida Ha <vida@databricks.com> Closes #2621 from vidaha/vida/SPARK-3752 and squashes the following commits: d7fdbbc [Vida Ha] Add tests for different UDF's
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java26
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java51
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java38
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java26
-rw-r--r--sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java28
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala111
6 files changed, 265 insertions, 15 deletions
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
new file mode 100644
index 0000000000..6c4f378bc5
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFIntegerToString extends UDF {
+  /**
+   * Renders an integer as its decimal string form.
+   *
+   * @param i the value to convert; may be {@code null}
+   * @return the result of {@code i.toString()}, or {@code null} when {@code i} is {@code null}
+   */
+  public String evaluate(Integer i) {
+    // Guard explicitly: calling toString() on a null argument would throw a NullPointerException.
+    return i == null ? null : i.toString();
+  }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
new file mode 100644
index 0000000000..d2d39a8c4d
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+
+public class UDFListListInt extends UDF {
+ /**
+ *
+ * @param obj
+ * SQL schema: array<struct<x: int, y: int, z: int>>
+ * Java Type: List<List<Integer>>
+ * @return
+ */
+ public long evaluate(Object obj) {
+ if (obj == null) {
+ return 0l;
+ }
+ List<List> listList = (List<List>) obj;
+ long retVal = 0;
+ for (List aList : listList) {
+ @SuppressWarnings("unchecked")
+ List<Object> list = (List<Object>) aList;
+ @SuppressWarnings("unchecked")
+ Integer someInt = (Integer) list.get(1);
+ try {
+ retVal += (long) (someInt.intValue());
+ } catch (NullPointerException e) {
+ System.out.println(e);
+ }
+ }
+ return retVal;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
new file mode 100644
index 0000000000..efd34df293
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+import org.apache.commons.lang.StringUtils;
+
+public class UDFListString extends UDF {
+
+ public String evaluate(Object a) {
+ if (a == null) {
+ return null;
+ }
+ @SuppressWarnings("unchecked")
+ List<Object> s = (List<Object>) a;
+
+ return StringUtils.join(s, ',');
+ }
+
+
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
new file mode 100644
index 0000000000..a369188d47
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFStringString extends UDF {
+ public String evaluate(String s1, String s2) {
+ return s1 + " " + s2;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
new file mode 100644
index 0000000000..0165591a7c
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFTwoListList extends UDF {
+ public String evaluate(Object o1, Object o2) {
+ UDFListListInt udf = new UDFListListInt();
+
+ return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2));
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index e4324e9528..872f28d514 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -17,33 +17,37 @@
package org.apache.spark.sql.hive.execution
-import java.io.{DataOutput, DataInput}
+import java.io.{DataInput, DataOutput}
import java.util
import java.util.Properties
-import org.apache.spark.util.Utils
-
-import scala.collection.JavaConversions._
-
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe}
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector}
-
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
-
-import org.apache.spark.sql.Row
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
+import org.apache.hadoop.io.Writable
+import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
+
+import org.apache.spark.util.Utils
+
+import scala.collection.JavaConversions._
case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int)
+// Case classes describing the rows of the temp tables queried by the custom UDF tests below.
+case class IntegerCaseClass(i: Int) // single int column `i`
+case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) // column `lli`: sequence of int triples
+case class StringCaseClass(s: String) // single string column `s`
+case class ListStringCaseClass(l: Seq[String]) // column `l`: sequence of strings
+
/**
* A test suite for Hive custom UDFs.
*/
-class HiveUdfSuite extends HiveComparisonTest {
+class HiveUdfSuite extends QueryTest {
+ import TestHive._
test("spark sql udf test that returns a struct") {
registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
@@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest {
}
test("SPARK-2693 udaf aggregates test") {
- assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first)
+ checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"),
+ sql("SELECT max(key) FROM src").collect().toSeq)
+ }
+
+ test("UDFIntegerToString") {
+   // Two single-column rows, so the UDF is evaluated once per integer.
+   val testData = TestHive.sparkContext.parallelize(
+     IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
+   val udfClass = classOf[UDFIntegerToString].getName
+
+   try {
+     testData.registerTempTable("integerTable")
+     sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfClass'")
+     checkAnswer(
+       sql("SELECT testUDFIntegerToString(i) FROM integerTable"),
+       Seq(Seq("1"), Seq("2")))
+   } finally {
+     // Clean up even when the check above fails, so the function and table do not outlive the test.
+     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
+     TestHive.reset()
+   }
+ }
+
+ test("UDFListListInt") {
+   // Rows with zero, one and two triples; the expected answers are the sums of the middle elements.
+   val testData = TestHive.sparkContext.parallelize(
+     ListListIntCaseClass(Nil) ::
+     ListListIntCaseClass(Seq((1, 2, 3))) ::
+     ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
+   val udfClass = classOf[UDFListListInt].getName
+
+   try {
+     testData.registerTempTable("listListIntTable")
+     sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '$udfClass'")
+     checkAnswer(
+       sql("SELECT testUDFListListInt(lli) FROM listListIntTable"),
+       Seq(Seq(0), Seq(2), Seq(13)))
+   } finally {
+     // Clean up even when the check above fails, so the function and table do not outlive the test.
+     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
+     TestHive.reset()
+   }
+ }
+
+ test("UDFListString") {
+   // Each row holds a list of strings that the UDF is expected to join with commas.
+   val testData = TestHive.sparkContext.parallelize(
+     ListStringCaseClass(Seq("a", "b", "c")) ::
+     ListStringCaseClass(Seq("d", "e")) :: Nil)
+   val udfClass = classOf[UDFListString].getName
+
+   try {
+     testData.registerTempTable("listStringTable")
+     sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '$udfClass'")
+     checkAnswer(
+       sql("SELECT testUDFListString(l) FROM listStringTable"),
+       Seq(Seq("a,b,c"), Seq("d,e")))
+   } finally {
+     // Clean up even when the check above fails, so the function and table do not outlive the test.
+     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
+     TestHive.reset()
+   }
+ }
+
+ test("UDFStringString") {
+   // The UDF is called with a literal first argument and the table column as the second.
+   val testData = TestHive.sparkContext.parallelize(
+     StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
+   val udfClass = classOf[UDFStringString].getName
+
+   try {
+     testData.registerTempTable("stringTable")
+     sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '$udfClass'")
+     checkAnswer(
+       sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"),
+       Seq(Seq("hello world"), Seq("hello goodbye")))
+   } finally {
+     // Clean up even when the check above fails, so the function and table do not outlive the test.
+     sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")
+     TestHive.reset()
+   }
+ }
+
+ test("UDFTwoListList") {
+   // The same column is passed as both arguments, so each expected answer repeats one sum twice.
+   val testData = TestHive.sparkContext.parallelize(
+     ListListIntCaseClass(Nil) ::
+     ListListIntCaseClass(Seq((1, 2, 3))) ::
+     ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
+     Nil)
+   val udfClass = classOf[UDFTwoListList].getName
+
+   try {
+     testData.registerTempTable("TwoListTable")
+     sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '$udfClass'")
+     checkAnswer(
+       sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"),
+       Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
+   } finally {
+     // Clean up even when the check above fails, so the function and table do not outlive the test.
+     sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
+     TestHive.reset()
+   }
 }
}