From d7dbce8f7da8a7fd01df6633a6043f51161b7d18 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Apr 2015 15:34:05 -0700 Subject: [SPARK-7156][SQL] support RandomSplit in DataFrames This is built on top of kaka1992 's PR #5711 using Logical plans. Author: Burak Yavuz Closes #5761 from brkyvz/random-sample and squashes the following commits: a1fb0aa [Burak Yavuz] remove unrelated file 69669c3 [Burak Yavuz] fix broken test 1ddb3da [Burak Yavuz] copy base 6000328 [Burak Yavuz] added python api and fixed test 3c11d1b [Burak Yavuz] fixed broken test f400ade [Burak Yavuz] fix build errors 2384266 [Burak Yavuz] addressed comments v0.1 e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames --- python/pyspark/sql/dataframe.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'python/pyspark') diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d9cbbc68b3..3074af3ed2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -426,7 +426,7 @@ class DataFrame(object): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - >>> df.sample(False, 0.5, 97).count() + >>> df.sample(False, 0.5, 42).count() 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction @@ -434,6 +434,22 @@ class DataFrame(object): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + def randomSplit(self, weights, seed=None): + """Randomly splits this :class:`DataFrame` with the provided weights. + + >>> splits = df4.randomSplit([1.0, 2.0], 24) + >>> splits[0].count() + 1 + + >>> splits[1].count() + 3 + """ + for w in weights: + assert w >= 0.0, "Negative weight value: %s" % w + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + @property def dtypes(self): """Returns all column names and their data types as a list. -- cgit v1.2.3