diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2015-04-30 21:56:03 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-04-30 21:56:03 -0700 |
commit | b5347a4664625ede6ab9d8ef6558457a34ae423f (patch) | |
tree | 6d068b5ede427fa4e4658d5d3ebcd9798c64950f /python/pyspark | |
parent | 69a739c7f5fd002432ece203957e1458deb2f4c3 (diff) | |
download | spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.tar.gz spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.tar.bz2 spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.zip |
[SPARK-7248] implemented random number generators for DataFrames
Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as expressions to DataFrames.
cc mengxr rxin
Author: Burak Yavuz <brkyvz@gmail.com>
Closes #5819 from brkyvz/df-rng and squashes the following commits:
50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/functions.py | 25 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 10 |
2 files changed, 34 insertions, 1 deletions
def rand(seed=None):
    """Generate a random column with i.i.d. samples from U[0.0, 1.0].

    :param seed: optional seed for the JVM-side random number generator.
        ``0`` is a legitimate seed value, so presence is tested with
        ``is not None`` — a bare truthiness check (``if seed:``) would
        silently ignore ``seed=0`` and produce a non-reproducible column.
    :return: a :class:`Column` wrapping the JVM ``functions.rand`` expression.
    """
    sc = SparkContext._active_spark_context
    # BUG FIX: was `if seed:` which discards seed=0; compare against None.
    if seed is not None:
        jc = sc._jvm.functions.rand(seed)
    else:
        jc = sc._jvm.functions.rand()
    return Column(jc)


def randn(seed=None):
    """Generate a column with i.i.d. samples from the standard normal distribution.

    :param seed: optional seed for the JVM-side random number generator.
        ``0`` is a legitimate seed value, so presence is tested with
        ``is not None`` (see :func:`rand` for the rationale).
    :return: a :class:`Column` wrapping the JVM ``functions.randn`` expression.
    """
    sc = SparkContext._active_spark_context
    # BUG FIX: was `if seed:` which discards seed=0; compare against None.
    if seed is not None:
        jc = sc._jvm.functions.randn(seed)
    else:
        jc = sc._jvm.functions.randn()
    return Column(jc)