aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/functions.py25
-rw-r--r--python/pyspark/sql/tests.py10
2 files changed, 34 insertions, 1 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 555c2fa5e7..241f821757 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -67,7 +67,6 @@ _functions = {
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}
-
for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
del _name, _doc
@@ -75,6 +74,30 @@ __all__ += _functions.keys()
__all__.sort()
def rand(seed=None):
    """
    Generate a random column with i.i.d. samples from U[0.0, 1.0].

    :param seed: optional seed for the random generator; when ``None``
        a nondeterministic seed is chosen by the JVM side.
    :return: a :class:`Column` of uniform random doubles.
    """
    sc = SparkContext._active_spark_context
    # Compare against None explicitly: a seed of 0 is falsy but is a
    # perfectly valid deterministic seed, and must not silently fall
    # through to the unseeded (nondeterministic) overload.
    if seed is not None:
        jc = sc._jvm.functions.rand(seed)
    else:
        jc = sc._jvm.functions.rand()
    return Column(jc)
+
+
def randn(seed=None):
    """
    Generate a column with i.i.d. samples from the standard normal
    distribution.

    :param seed: optional seed for the random generator; when ``None``
        a nondeterministic seed is chosen by the JVM side.
    :return: a :class:`Column` of N(0, 1) random doubles.
    """
    sc = SparkContext._active_spark_context
    # `is not None`, not truthiness: seed=0 is a valid deterministic seed
    # and would otherwise be dropped, selecting the unseeded overload.
    if seed is not None:
        jc = sc._jvm.functions.randn(seed)
    else:
        jc = sc._jvm.functions.randn()
    return Column(jc)
+
+
def approxCountDistinct(col, rsd=None):
"""Returns a new :class:`Column` for approximate distinct count of ``col``.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18ebd7..5640bb5ea2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
+ def test_rand_functions(self):
+ df = self.df
+ from pyspark.sql import functions
+ rnd = df.select('key', functions.rand()).collect()
+ for row in rnd:
+ assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+ rndn = df.select('key', functions.randn(5)).collect()
+ for row in rndn:
+ assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()