From b5347a4664625ede6ab9d8ef6558457a34ae423f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Thu, 30 Apr 2015 21:56:03 -0700 Subject: [SPARK-7248] implemented random number generators for DataFrames Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as expressions to DataFrames. cc mengxr rxin Author: Burak Yavuz Closes #5819 from brkyvz/df-rng and squashes the following commits: 50d69d4 [Burak Yavuz] add seed for test that failed 4234c3a [Burak Yavuz] fix Rand expression 13cad5c [Burak Yavuz] couple fixes 7d53953 [Burak Yavuz] waiting for hive tests b453716 [Burak Yavuz] move radn with seed down 03637f0 [Burak Yavuz] fix broken hive func c5909eb [Burak Yavuz] deleted old implementation of Rand 6d43895 [Burak Yavuz] implemented random generators --- python/pyspark/sql/functions.py | 25 ++++++++++++++++++++++++- python/pyspark/sql/tests.py | 10 ++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) (limited to 'python/pyspark/sql') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 555c2fa5e7..241f821757 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -67,7 +67,6 @@ _functions = { 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', } - for _name, _doc in _functions.items(): globals()[_name] = _create_function(_name, _doc) del _name, _doc @@ -75,6 +74,30 @@ __all__ += _functions.keys() __all__.sort() +def rand(seed=None): + """ + Generate a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +def randn(seed=None): + """ + Generate a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + def approxCountDistinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2ffd18ebd7..5640bb5ea2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase): assert_close([math.hypot(i, 2 * i) for i in range(10)], df.select(functions.hypot(df.a, df.b)).collect()) + def test_rand_functions(self): + df = self.df + from pyspark.sql import functions + rnd = df.select('key', functions.rand()).collect() + for row in rnd: + assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1] + rndn = df.select('key', functions.randn(5)).collect() + for row in rndn: + assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() -- cgit v1.2.3