diff options
author | Yin Huai <yhuai@databricks.com> | 2015-08-06 17:03:14 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-06 17:03:21 -0700 |
commit | 75b4e5ab306166ea5f9c65eb6d0a14dfc26ef2a9 (patch) | |
tree | 183268b4d81c6e773d1c305cfff3bcada319cc06 /python | |
parent | 985e454cb6a8b97e1b1d4a1c8b2fa86db1830d22 (diff) | |
download | spark-75b4e5ab306166ea5f9c65eb6d0a14dfc26ef2a9.tar.gz spark-75b4e5ab306166ea5f9c65eb6d0a14dfc26ef2a9.tar.bz2 spark-75b4e5ab306166ea5f9c65eb6d0a14dfc26ef2a9.zip |
[SPARK-9691] [SQL] PySpark SQL rand function treats seed 0 as no seed
https://issues.apache.org/jira/browse/SPARK-9691
jkbradley rxin
Author: Yin Huai <yhuai@databricks.com>
Closes #7999 from yhuai/pythonRand and squashes the following commits:
4187e0c [Yin Huai] Regression test.
a985ef9 [Yin Huai] Use "if seed is not None" instead "if seed" because "if seed" returns false when seed is 0.
(cherry picked from commit baf4587a569b49e39020c04c2785041bdd00789b)
Signed-off-by: Reynold Xin <rxin@databricks.com>
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql/functions.py | 4 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 10 |
2 files changed, 12 insertions, 2 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5c6a01f18..95f46044d3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -268,7 +268,7 @@ def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.rand(seed) else: jc = sc._jvm.functions.rand() @@ -280,7 +280,7 @@ def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ sc = SparkContext._active_spark_context - if seed: + if seed is not None: jc = sc._jvm.functions.randn(seed) else: jc = sc._jvm.functions.randn() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ebd3ea8db6..1e3444dd9e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -629,6 +629,16 @@ class SQLTests(ReusedPySparkTestCase): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + # If the specified seed is 0, we should use it. + # https://issues.apache.org/jira/browse/SPARK-9691 + rnd1 = df.select('key', functions.rand(0)).collect() + rnd2 = df.select('key', functions.rand(0)).collect() + self.assertEqual(sorted(rnd1), sorted(rnd2)) + + rndn1 = df.select('key', functions.randn(0)).collect() + rndn2 = df.select('key', functions.randn(0)).collect() + self.assertEqual(sorted(rndn1), sorted(rndn2)) + def test_between_function(self): df = self.sc.parallelize([ Row(a=1, b=2, c=3), |