aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/dataframe.py40
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala24
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala12
3 files changed, 74 insertions, 2 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b87351d..213338dfe5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -448,6 +448,41 @@ class DataFrame(object):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given on each stratum.
+
+        :param col: name of the column that defines strata (a string)
+        :param fractions: dict of stratum -> sampling fraction. If a stratum
+            is not specified, we treat its fraction as zero.
+        :param seed: random seed; chosen at random when None
+        :return: a new DataFrame that represents the stratified sample
+        :raises ValueError: if an argument or a stratum key has an unsupported type
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    5|
+        |  1|    8|
+        +---+-----+
+        """
+        if not isinstance(col, basestring):  # basestring: also accept unicode on Python 2
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict but got %r" % type(fractions))
+        for k in fractions:
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+        fractions = dict((k, float(v)) for k, v in fractions.items())  # new dict; input untouched
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
@since(1.4)
def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1322,6 +1357,11 @@ class DataFrameStatFunctions(object):
freqItems.__doc__ = DataFrame.freqItems.__doc__
+ def sampleBy(self, col, fractions, seed=None):
+ return self.df.sampleBy(col, fractions, seed)  # thin alias: forwards the arguments unchanged
+
+ sampleBy.__doc__ = DataFrame.sampleBy.__doc__  # reuse the DataFrame method's docstring
+
def _test():
import doctest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba..955d28771b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.util.UUID
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat._
@@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
def freqItems(cols: Seq[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero. An empty map therefore yields an empty sample.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.{lit, rand}
+    val c = Column(col)
+    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.foldLeft(lit(false))(_ || _)  // fold, not reduce: reduce throws on an empty map
+    df.filter(expr)
+  }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0d3ff899da..3dd4688912 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
import org.scalatest.Matchers._
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite {
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}
+
+ test("sampleBy") {
+ val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))  // ids 0..99 folded into keys 0, 1, 2
+ val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)  // key 2 is given no fraction
+ checkAnswer(
+ sampled.groupBy("key").count().orderBy("key"),
+ Seq(Row(0, 4), Row(1, 9)))  // expected per-key counts for seed 0L; no row expected for key 2
+ }
}