diff options
Diffstat (limited to 'sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala')
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 2bcbb1983f..91095af0dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -354,4 +354,38 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { val df = src.select($"*", max("c").over(winSpec) as "max") checkAnswer(df, Row(5, Row(0, 3), 5)) } + + test("aggregation and rows between with unbounded + predicate pushdown") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + val selectList = Seq($"key", $"value", + last("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))) + + checkAnswer( + df.select(selectList: _*).where($"value" < "3"), + Seq(Row(1, "1", 1, 1, 1), Row(2, "2", 3, 2, 3), Row(3, "2", 3, 3, 3))) + } + + test("aggregation and range between with unbounded + predicate pushdown") { + val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + val selectList = Seq($"key", $"value", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)).equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 1)) + .as("avg_key3")) + + checkAnswer( + df.select(selectList: _*).where($"value" < 2), + Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) + } } |