author     Cheng Hao <hao.cheng@intel.com>     2015-05-22 01:00:16 -0700
committer  Reynold Xin <rxin@databricks.com>   2015-05-22 01:00:16 -0700
commit     f6f2eeb17910b5d446dfd61839e37dd698d0860f (patch)
tree       dabde9be745a52c4d4af0f2177f795dcb8cc005b /sql/hive
parent     2728c3df6690c2fcd4af3bd1c604c98ef6d509a5 (diff)
[SPARK-7322][SQL] Window functions in DataFrame
This closes #6104.

Author: Cheng Hao <hao.cheng@intel.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6343 from rxin/window-df and squashes the following commits:

026d587 [Reynold Xin] Address code review feedback.
dc448fe [Reynold Xin] Fixed Hive tests.
9794d9d [Reynold Xin] Moved Java test package.
9331605 [Reynold Xin] Refactored API.
3313e2a [Reynold Xin] Merge pull request #6104 from chenghao-intel/df_window
d625a64 [Cheng Hao] Update the DataFrame window API as suggested
c141fb1 [Cheng Hao] Hide all properties of WindowFunctionDefinition
3b1865f [Cheng Hao] Fix scaladoc typos
f3fd2d0 [Cheng Hao] Polish the unit tests
6847825 [Cheng Hao] Add additional analytics functions
57e3bc0 [Cheng Hao] Fix typos
24a08ec [Cheng Hao] scaladoc
28222ed [Cheng Hao] Fix bug in range/row frame
1d91865 [Cheng Hao] Fix style issue
53f89f2 [Cheng Hao] Remove over() from functions.scala
964c013 [Cheng Hao] Add more unit tests and window functions
64e18a7 [Cheng Hao] Add Window Function support for DataFrame
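For orientation, here is a minimal sketch of the DataFrame window API this patch adds, written against the calls exercised in the suites below. The DataFrame df with key/value columns mirrors the shape of the test fixtures and is assumed here; this is an annotation, not part of the patch.

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Build a reusable window specification: partition rows by value,
    // order within each partition by key, and use a sliding frame of
    // one row before through one row after the current row.
    val w = Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1)

    // Apply an aggregate over the window, alongside a ranking function.
    df.select(
      avg("key").over(w),
      rank().over(Window.partitionBy("value").orderBy("key")))

The avg("key").over(w) expression is equivalent to the HiveQL form AVG(key) OVER (PARTITION BY value ORDER BY key ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), which the new tests use as the reference answer.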
Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java  78
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java)  4
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java)  0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java)  0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java)  0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java)  0
-rw-r--r--  sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java (renamed from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java)  0
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala  219
8 files changed, 300 insertions, 1 deletion
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
new file mode 100644
index 0000000000..c4828c4717
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.expressions.Window;
+import org.apache.spark.sql.hive.HiveContext;
+import org.apache.spark.sql.hive.test.TestHive$;
+
+public class JavaDataFrameSuite {
+ private transient JavaSparkContext sc;
+ private transient HiveContext hc;
+
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List<Row> expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ hc = TestHive$.MODULE$;
+ sc = new JavaSparkContext(hc.sparkContext());
+
+ List<String> jsonObjects = new ArrayList<String>(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
+ }
+ df = hc.jsonRDD(sc.parallelize(jsonObjects));
+ df.registerTempTable("window_table");
+ }
+
+ @After
+ public void tearDown() throws IOException {
+ // Clean up tables.
+ hc.sql("DROP TABLE IF EXISTS window_table");
+ }
+
+ @Test
+ public void saveTableAndQueryIt() {
+ checkAnswer(
+ df.select(functions.avg("key").over(
+ Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
+ hc.sql("SELECT avg(key) " +
+ "OVER (PARTITION BY value " +
+ " ORDER BY key " +
+ " ROWS BETWEEN 1 preceding and 1 following) " +
+ "FROM window_table").collectAsList());
+ }
+}
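Note on the frame bounds above: rowsBetween takes relative row offsets, so a negative start means N PRECEDING and a positive end means N FOLLOWING; rowsBetween(-1, 1) is therefore the same frame as the ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING clause in the HiveQL half of the assertion. A small sketch of the correspondence (annotation, not part of the patch):

    // DataFrame side: offsets are relative to the current row.
    Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1)
    // HiveQL side: the same frame spelled out.
    //   ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING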
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 58fe96adab..64d1ce9293 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.hive;
+
+package test.org.apache.spark.sql.hive;
import java.io.File;
import java.io.IOException;
@@ -36,6 +37,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.QueryTest$;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.hive.test.TestHive$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
index 6c4f378bc5..6c4f378bc5 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java
index 808e2986d3..808e2986d3 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java
index f33210ebda..f33210ebda 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java
index a369188d47..a369188d47 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java
index 0165591a7c..0165591a7c 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
new file mode 100644
index 0000000000..6cea6776c8
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.{Row, QueryTest}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+class HiveDataFrameWindowSuite extends QueryTest {
+
+ test("reuse window partitionBy") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ val w = Window.partitionBy("key").orderBy("value")
+
+ checkAnswer(
+ df.select(
+ lead("key").over(w),
+ lead("value").over(w)),
+ Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
+ }
+
+ test("reuse window orderBy") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ val w = Window.orderBy("value").partitionBy("key")
+
+ checkAnswer(
+ df.select(
+ lead("key").over(w),
+ lead("value").over(w)),
+ Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
+ }
+
+ test("lead") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+
+ checkAnswer(
+ df.select(
+ lead("value").over(Window.partitionBy($"key").orderBy($"value"))),
+ sql(
+ """SELECT
+ | lead(value) OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lag") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+
+ checkAnswer(
+ df.select(
+ lag("value").over(
+ Window.partitionBy($"key")
+ .orderBy($"value"))),
+ sql(
+ """SELECT
+ | lag(value) OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lead with default value") {
+ val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+ (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))),
+ sql(
+ """SELECT
+ | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("lag with default value") {
+ val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
+ (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))),
+ sql(
+ """SELECT
+ | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("rank functions in unspecific window") {
+ val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ max("key").over(Window.partitionBy("value").orderBy("key")),
+ min("key").over(Window.partitionBy("value").orderBy("key")),
+ mean("key").over(Window.partitionBy("value").orderBy("key")),
+ count("key").over(Window.partitionBy("value").orderBy("key")),
+ sum("key").over(Window.partitionBy("value").orderBy("key")),
+ ntile("key").over(Window.partitionBy("value").orderBy("key")),
+ ntile($"key").over(Window.partitionBy("value").orderBy("key")),
+ rowNumber().over(Window.partitionBy("value").orderBy("key")),
+ denseRank().over(Window.partitionBy("value").orderBy("key")),
+ rank().over(Window.partitionBy("value").orderBy("key")),
+ cumeDist().over(Window.partitionBy("value").orderBy("key")),
+ percentRank().over(Window.partitionBy("value").orderBy("key"))),
+ sql(
+ s"""SELECT
+ |key,
+ |max(key) over (partition by value order by key),
+ |min(key) over (partition by value order by key),
+ |avg(key) over (partition by value order by key),
+ |count(key) over (partition by value order by key),
+ |sum(key) over (partition by value order by key),
+ |ntile(key) over (partition by value order by key),
+ |ntile(key) over (partition by value order by key),
+ |row_number() over (partition by value order by key),
+ |dense_rank() over (partition by value order by key),
+ |rank() over (partition by value order by key),
+ |cume_dist() over (partition by value order by key),
+ |percent_rank() over (partition by value order by key)
+ |FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and rows between") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
+ sql(
+ """SELECT
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and range betweens") {
+ val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))),
+ sql(
+ """SELECT
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and rows betweens with unbounded") {
+ val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
+ last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))),
+ sql(
+ """SELECT
+ | key,
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following),
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row),
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following)
+ | FROM window_table""".stripMargin).collect())
+ }
+
+ test("aggregation and range betweens with unbounded") {
+ val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ checkAnswer(
+ df.select(
+ $"key",
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue))
+ .equalTo("2")
+ .as("last_v"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
+ .as("avg_key1"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
+ .as("avg_key2"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
+ .as("avg_key3")
+ ),
+ sql(
+ """SELECT
+ | key,
+ | last_value(value) OVER
+ | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2",
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following),
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following),
+ | avg(key) OVER
+ | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row)
+ | FROM window_table""".stripMargin).collect())
+ }
+}
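The unbounded-frame tests above rely on sentinel offsets: judging from how the suite pairs them with HiveQL, Long.MinValue stands for UNBOUNDED PRECEDING, Long.MaxValue for UNBOUNDED FOLLOWING, and 0 for CURRENT ROW. A compact sketch of that mapping, using only calls exercised in the suite (annotation, not part of the patch):

    import org.apache.spark.sql.expressions.Window

    // ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
    val tail = Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)

    // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    val head = Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)

    // RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING
    val range = Window.partitionBy($"value").orderBy($"key").rangeBetween(Long.MinValue, 1)

rangeBetween differs from rowsBetween in that its offsets are compared against the ORDER BY value rather than counted as physical rows, which is why the last suite orders by key and frames rows by their distance in key values.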