diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-07-30 17:16:03 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-30 17:16:03 -0700 |
commit | df32669514afc0223ecdeca30fbfbe0b40baef3a (patch) | |
tree | a23fde19657010f2245a72ac04450b8d33fe07b7 /python/pyspark/sql | |
parent | ca71cc8c8b2d64b7756ae697c06876cd18b536dc (diff) | |
download | spark-df32669514afc0223ecdeca30fbfbe0b40baef3a.tar.gz spark-df32669514afc0223ecdeca30fbfbe0b40baef3a.tar.bz2 spark-df32669514afc0223ecdeca30fbfbe0b40baef3a.zip |
[SPARK-7157][SQL] add sampleBy to DataFrame
This was previously committed but then reverted due to test failures (see #6769).
Author: Xiangrui Meng <meng@databricks.com>
Closes #7755 from rxin/SPARK-7157 and squashes the following commits:
fbf9044 [Xiangrui Meng] fix python test
542bd37 [Xiangrui Meng] update test
604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
f051afd [Xiangrui Meng] use udf instead of building expression
f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
103beb3 [Xiangrui Meng] add Java-friendly sampleBy
991f26f [Xiangrui Meng] fix seed
4a14834 [Xiangrui Meng] move sampleBy to stat
832f7cc [Xiangrui Meng] add sampleBy to DataFrame
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76e051bd7..0f3480c239 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ class DataFrame(object): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1314,6 +1350,11 @@ class DataFrameStatFunctions(object): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest |