aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/dataframe.py
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-04-29 19:13:47 -0700
committerReynold Xin <rxin@databricks.com>2015-04-29 19:13:47 -0700
commit5553198fe521fb38b600b7687f7780d89a6e1cb9 (patch)
tree6dabbcd1430e2cdd89fd900a4d06fd2656bbd8d0 /python/pyspark/sql/dataframe.py
parent7143f6e9718bae9cffa0a73df03ba8c9860ee129 (diff)
downloadspark-5553198fe521fb38b600b7687f7780d89a6e1cb9.tar.gz
spark-5553198fe521fb38b600b7687f7780d89a6e1cb9.tar.bz2
spark-5553198fe521fb38b600b7687f7780d89a6e1cb9.zip
[SPARK-7156][SQL] Addressed follow up comments for randomSplit
small fixes regarding comments in PR #5761 cc rxin Author: Burak Yavuz <brkyvz@gmail.com> Closes #5795 from brkyvz/split-followup and squashes the following commits: 369c522 [Burak Yavuz] changed wording a little 1ea456f [Burak Yavuz] Addressed follow up comments
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--python/pyspark/sql/dataframe.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3074af3ed2..5908ebc990 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -437,6 +437,10 @@ class DataFrame(object):
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
+ :param weights: list of doubles as weights with which to split the DataFrame. Weights will
+ be normalized if they don't sum up to 1.0.
+ :param seed: The seed for sampling.
+
>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
1
@@ -445,7 +449,8 @@ class DataFrame(object):
3
"""
for w in weights:
- assert w >= 0.0, "Negative weight value: %s" % w
+ if w < 0.0:
+ raise ValueError("Weights must be positive. Found weight value: %s" % w)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]