about summary refs log tree commit diff
path: root/python/pyspark
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/functions.py  19
1 file changed, 16 insertions, 3 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 5017ab5b36..dac842c0ce 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -467,16 +467,29 @@ def randn(seed=None):
@since(1.5)
def round(col, scale=0):
"""
- Round the value of `e` to `scale` decimal places if `scale` >= 0
+ Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
or at integral part when `scale` < 0.
- >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect()
- [Row(r=2.5)]
+ >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
+ [Row(r=3.0)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.round(_to_java_column(col), scale))
+@since(2.0)
+def bround(col, scale=0):
+ """
+ Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
+ or at integral part when `scale` < 0.
+
+ >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
+ [Row(r=2.0)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.bround(_to_java_column(col), scale))
+
+
@since(1.5)
def shiftLeft(col, numBits):
"""Shift the given value numBits left.