diff options
3 files changed, 2 insertions, 74 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 213338dfe5..152b87351d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -448,41 +448,6 @@ class DataFrame(object): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) - @since(1.5) - def sampleBy(self, col, fractions, seed=None): - """ - Returns a stratified sample without replacement based on the - fraction given on each stratum. - - :param col: column that defines strata - :param fractions: - sampling fraction for each stratum. If a stratum is not - specified, we treat its fraction as zero. - :param seed: random seed - :return: a new DataFrame that represents the stratified sample - - >>> from pyspark.sql.functions import col - >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) - >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) - >>> sampled.groupBy("key").count().orderBy("key").show() - +---+-----+ - |key|count| - +---+-----+ - | 0| 5| - | 1| 8| - +---+-----+ - """ - if not isinstance(col, str): - raise ValueError("col must be a string, but got %r" % type(col)) - if not isinstance(fractions, dict): - raise ValueError("fractions must be a dict but got %r" % type(fractions)) - for k, v in fractions.items(): - if not isinstance(k, (float, int, long, basestring)): - raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) - fractions[k] = float(v) - seed = seed if seed is not None else random.randint(0, sys.maxsize) - return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) - @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1357,11 +1322,6 @@ class DataFrameStatFunctions(object): freqItems.__doc__ = DataFrame.freqItems.__doc__ - def sampleBy(self, col, fractions, seed=None): - return self.df.sampleBy(col, fractions, seed) - - sampleBy.__doc__ = DataFrame.sampleBy.__doc__ - def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 955d28771b..edb9ed7bba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.UUID - import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -165,26 +163,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } - - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @return a new [[DataFrame]] that represents the stratified sample - * - * @since 1.5.0 - */ - def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = { - require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), - s"Fractions must be in [0, 1], but got $fractions.") - import org.apache.spark.sql.functions.rand - val c = Column(col) - val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8)) - val expr = fractions.toSeq.map { case (k, v) => - (c === k) && (r < v) - }.reduce(_ || _) || false - df.filter(expr) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 3dd4688912..0d3ff899da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.functions.col +import org.apache.spark.SparkFunSuite -class DataFrameStatSuite extends QueryTest { +class DataFrameStatSuite extends SparkFunSuite { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -98,12 +98,4 @@ class DataFrameStatSuite extends QueryTest { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } - - test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) - val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) - checkAnswer( - sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 4), Row(1, 9))) - } } |