 sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+), 0 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
index 2bcbb1983f..91095af0dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
@@ -354,4 +354,38 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
val df = src.select($"*", max("c").over(winSpec) as "max")
checkAnswer(df, Row(5, Row(0, 3), 5))
}
+
+ test("aggregation and rows between with unbounded + predicate pushdown") {
+ val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ val selectList = Seq($"key", $"value",
+ last("key").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
+ last("key").over(
+ Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
+ last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1)))
+
+ checkAnswer(
+ df.select(selectList: _*).where($"value" < "3"),
+ Seq(Row(1, "1", 1, 1, 1), Row(2, "2", 3, 2, 3), Row(3, "2", 3, 3, 3)))
+ }
+
+ test("aggregation and range between with unbounded + predicate pushdown") {
+ val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value")
+ df.registerTempTable("window_table")
+ val selectList = Seq($"key", $"value",
+ last("value").over(
+ Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)).equalTo("2")
+ .as("last_v"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
+ .as("avg_key1"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
+ .as("avg_key2"),
+ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 1))
+ .as("avg_key3"))
+
+ checkAnswer(
+ df.select(selectList: _*).where($"value" < 2),
+ Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0)))
+ }
}
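
The two new tests exercise ROWS and RANGE window frames (rowsBetween/rangeBetween) combined with a filter on the partitioning column, the kind of predicate the optimizer can push below the Window operator. The following standalone sketch is not part of the patch; it is a minimal illustration of the same DataFrame window API outside the test harness, assuming Spark 2.x where SparkSession is available (the suite itself relies on SharedSQLContext instead), and it reuses the data and frame bounds from the first test above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, last}

val spark = SparkSession.builder().master("local[*]").appName("window-frames").getOrCreate()
import spark.implicits._

val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")

// ROWS frame: the bounds count physical rows around the current row.
val toEnd   = Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)
val fromTop = Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)

// RANGE frame: the bounds are offsets on the ordering value (key - 1 .. key + 1).
val nearby  = Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1)

df.select($"key", $"value",
    last("key").over(toEnd).as("last_to_end"),    // ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
    last("key").over(fromTop).as("last_so_far"),  // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    avg("key").over(nearby).as("avg_near"))       // RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING
  .where($"value" < "3")  // filter on the partition column, as in the pushdown tests
  .show()

The column aliases (last_to_end, last_so_far, avg_near) are illustrative names, not identifiers from the patch; the frame bounds in SQL terms are noted in the comments.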