aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-04-29 15:34:05 -0700
committerReynold Xin <rxin@databricks.com>2015-04-29 15:34:05 -0700
commitd7dbce8f7da8a7fd01df6633a6043f51161b7d18 (patch)
tree6e57f1d7614527f4796071a541171cd07f8d98d2 /python
parentc9d530e2e5123dbd4fd13fc487c890d6076b24bf (diff)
downloadspark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.gz
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.bz2
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.zip
[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992's PR #5711 using Logical plans. Author: Burak Yavuz <brkyvz@gmail.com> Closes #5761 from brkyvz/random-sample and squashes the following commits: a1fb0aa [Burak Yavuz] remove unrelated file 69669c3 [Burak Yavuz] fix broken test 1ddb3da [Burak Yavuz] copy base 6000328 [Burak Yavuz] added python api and fixed test 3c11d1b [Burak Yavuz] fixed broken test f400ade [Burak Yavuz] fix build errors 2384266 [Burak Yavuz] addressed comments v0.1 e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/dataframe.py18
1 file changed, 17 insertions, 1 deletion
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d9cbbc68b3..3074af3ed2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -426,7 +426,7 @@ class DataFrame(object):
def sample(self, withReplacement, fraction, seed=None):
"""Returns a sampled subset of this :class:`DataFrame`.
- >>> df.sample(False, 0.5, 97).count()
+ >>> df.sample(False, 0.5, 42).count()
1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
@@ -434,6 +434,22 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
+ def randomSplit(self, weights, seed=None):
+ """Randomly splits this :class:`DataFrame` with the provided weights.
+
+ >>> splits = df4.randomSplit([1.0, 2.0], 24)
+ >>> splits[0].count()
+ 1
+
+ >>> splits[1].count()
+ 3
+ """
+ for w in weights:
+ assert w >= 0.0, "Negative weight value: %s" % w
+ seed = seed if seed is not None else random.randint(0, sys.maxsize)
+ rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
+ return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
+
@property
def dtypes(self):
"""Returns all column names and their data types as a list.