path: root/python/pyspark/sql/dataframe.py
author: Xiangrui Meng <meng@databricks.com>  2015-06-23 17:46:29 -0700
committer: Reynold Xin <rxin@databricks.com>  2015-06-23 17:46:29 -0700
commit: 0401cbaa8ee51c71f43604f338b65022a479da0a (patch)
tree: 7c8c5fd448576146a6677a907a6d8acf0e3b16c0 /python/pyspark/sql/dataframe.py
parent: 111d6b9b8a584b962b6ae80c7aa8c45845ce0099 (diff)
[SPARK-7157][SQL] add sampleBy to DataFrame
Add `sampleBy` to DataFrame. rxin

Author: Xiangrui Meng <meng@databricks.com>

Closes #6769 from mengxr/SPARK-7157 and squashes the following commits:

991f26f [Xiangrui Meng] fix seed
4a14834 [Xiangrui Meng] move sampleBy to stat
832f7cc [Xiangrui Meng] add sampleBy to DataFrame
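A minimal usage sketch of the new API (assuming a Spark 1.5-era SQLContext is bound to the name sqlContext, the same setup the doctest in the diff below relies on):

from pyspark.sql.functions import col

# Build a DataFrame whose "key" column takes the values 0, 1 and 2.
dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))

# Keep roughly 10% of the rows with key 0 and 20% of those with key 1;
# strata missing from `fractions` (here key 2) are sampled at fraction 0.
sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
sampled.groupBy("key").count().orderBy("key").show()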
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r--  python/pyspark/sql/dataframe.py | 40
1 file changed, 40 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b87351d..213338dfe5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -448,6 +448,41 @@ class DataFrame(object):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given on each stratum.
+
+        :param col: column that defines strata
+        :param fractions:
+            sampling fraction for each stratum. If a stratum is not
+            specified, we treat its fraction as zero.
+        :param seed: random seed
+        :return: a new DataFrame that represents the stratified sample
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    5|
+        |  1|    8|
+        +---+-----+
+        """
+        if not isinstance(col, str):
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict but got %r" % type(fractions))
+        for k, v in fractions.items():
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+            fractions[k] = float(v)
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1322,6 +1357,11 @@ class DataFrameStatFunctions(object):
     freqItems.__doc__ = DataFrame.freqItems.__doc__
+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 def _test():
     import doctest
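For reference, the DataFrameStatFunctions method above only delegates to DataFrame.sampleBy, so the same stratified sample can also be drawn through the df.stat accessor; a small sketch, reusing the hypothetical dataset from the example near the top of this page:

# Equivalent call through the stat namespace; it simply forwards to DataFrame.sampleBy.
sampled_via_stat = dataset.stat.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)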