 python/pyspark/sql/functions.py |  7 ++++---
 python/pyspark/sql/tests.py     | 23 +++++++++++++++++++++++
 python/pyspark/sql/window.py    |  2 +-
 3 files changed, 28 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e98979533f..41dfee9f54 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -530,9 +530,10 @@ def lead(col, count=1, default=None):
 @since(1.4)
 def ntile(n):
     """
-    Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in
-    a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will
-    get 2, the third row will get 3, and the fourth row will get 1...
+    Window function: returns the ntile group id (from 1 to `n` inclusive)
+    in an ordered window partition. For example, if `n` is 4, the first
+    quarter of the rows will get value 1, the second quarter will get 2,
+    the third quarter will get 3, and the last quarter will get 4.
 
     This is equivalent to the NTILE function in SQL.
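
As an editorial aside (not part of the patch), a minimal sketch of the bucketing the corrected docstring describes, assuming a HiveContext-backed `sqlCtx` (window functions require Hive support in this era) and hypothetical data:

    from pyspark.sql import Window
    from pyspark.sql import functions as F

    df = sqlCtx.createDataFrame([(i,) for i in range(8)], ["key"])
    w = Window.orderBy("key")
    # ntile(4) over 8 ordered rows: rows 1-2 get bucket 1, rows 3-4 get 2,
    # rows 5-6 get 3, and rows 7-8 get 4.
    df.select(df.key, F.ntile(4).over(w).alias("bucket")).show()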
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 38c83c427a..9b748101b5 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1124,5 +1124,28 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         for r, ex in zip(rs, expected):
             self.assertEqual(tuple(r), ex[:len(r)])
 
+    def test_window_functions_without_partitionBy(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        w = Window.orderBy("key", df.value)
+        from pyspark.sql import functions as F
+        sel = df.select(df.value, df.key,
+                        F.max("key").over(w.rowsBetween(0, 1)),
+                        F.min("key").over(w.rowsBetween(0, 1)),
+                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
+                        F.rowNumber().over(w),
+                        F.rank().over(w),
+                        F.denseRank().over(w),
+                        F.ntile(2).over(w))
+        rs = sorted(sel.collect())
+        expected = [
+            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
+            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
+            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
+            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
+        ]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
+
 if __name__ == "__main__":
     unittest.main()
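
An editorial note, not part of the patch: each expected tuple lines up with the select list as (value, key, max, min, count, rowNumber, rank, denseRank, ntile). The two frame specifications read as:

    # rowsBetween(0, 1): the frame is the current row plus the one that
    # follows it, so max("key")/min("key") scan a sliding pair of rows.
    # rowsBetween(float('-inf'), float('inf')): an unbounded frame over the
    # whole ordered set, which is why count("key") is 4 on every row.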
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index c74745c726..eaf4d7e986 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -64,7 +64,7 @@ class Window(object):
         Creates a :class:`WindowSpec` with the partitioning defined.
         """
         sc = SparkContext._active_spark_context
-        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols))
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
         return WindowSpec(jspec)
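
To make the one-line fix concrete: this hunk patches Window.orderBy (as exercised by the new test above), which previously delegated to the JVM-side partitionBy, so an order-only spec silently partitioned rows by the given columns instead of ordering them. A hedged sketch of the behavior the fix restores, assuming a DataFrame `df` with a "key" column (hypothetical usage, not from the patch):

    from pyspark.sql import Window
    from pyspark.sql import functions as F

    w = Window.orderBy("key")            # now builds a genuine ORDER BY spec
    df.select(F.rowNumber().over(w))     # 1, 2, 3, ... over the ordered rows
    # Before the patch this was equivalent to Window.partitionBy("key"),
    # grouping rows rather than ordering them.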