aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
author Burak Yavuz <brkyvz@gmail.com> 2015-04-30 21:56:03 -0700
committer Reynold Xin <rxin@databricks.com> 2015-04-30 21:56:03 -0700
commit b5347a4664625ede6ab9d8ef6558457a34ae423f (patch)
tree 6d068b5ede427fa4e4658d5d3ebcd9798c64950f /python/pyspark/sql/tests.py
parent 69a739c7f5fd002432ece203957e1458deb2f4c3 (diff)
download spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.tar.gz
spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.tar.bz2
spark-b5347a4664625ede6ab9d8ef6558457a34ae423f.zip
[SPARK-7248] implemented random number generators for DataFrames
Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as expressions to DataFrames.

cc mengxr rxin

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5819 from brkyvz/df-rng and squashes the following commits:

50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py | 10
1 files changed, 10 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2ffd18ebd7..5640bb5ea2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -416,6 +416,16 @@ class SQLTests(ReusedPySparkTestCase):
assert_close([math.hypot(i, 2 * i) for i in range(10)],
df.select(functions.hypot(df.a, df.b)).collect())
+ def test_rand_functions(self):
+ df = self.df
+ from pyspark.sql import functions
+ rnd = df.select('key', functions.rand()).collect()
+ for row in rnd:
+ assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+ rndn = df.select('key', functions.randn(5)).collect()
+ for row in rndn:
+ assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()