author     Xiangrui Meng <meng@databricks.com>    2015-07-30 17:16:03 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-07-30 17:16:03 -0700
commit     df32669514afc0223ecdeca30fbfbe0b40baef3a (patch)
tree       a23fde19657010f2245a72ac04450b8d33fe07b7 /sql
parent     ca71cc8c8b2d64b7756ae697c06876cd18b536dc (diff)
[SPARK-7157][SQL] add sampleBy to DataFrame
This was previously committed but then reverted due to test failures (see #6769).

Author: Xiangrui Meng <meng@databricks.com>

Closes #7755 from rxin/SPARK-7157 and squashes the following commits:

fbf9044 [Xiangrui Meng] fix python test
542bd37 [Xiangrui Meng] update test
604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
f051afd [Xiangrui Meng] use udf instead of building expression
f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157
103beb3 [Xiangrui Meng] add Java-friendly sampleBy
991f26f [Xiangrui Meng] fix seed
4a14834 [Xiangrui Meng] move sampleBy to stat
832f7cc [Xiangrui Meng] add sampleBy to DataFrame
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 42
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java  |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala     | 12
3 files changed, 61 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4ec58082e7..2e68e358f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import java.{util => ju, lang => jl}
+
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
 
@@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.{rand, udf}
+    val c = Column(col)
+    val r = rand(seed)
+    val f = udf { (stratum: Any, x: Double) =>
+      x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0)
+    }
+    df.filter(f(c, r))
+  }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @tparam T stratum type
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
+    sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
+  }
 }
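For context, here is a minimal usage sketch of the new Scala API (a sketch, not part of the commit; it assumes the TestSQLContext that the suites below also use, and mirrors the tests' strata and fractions). sampleBy draws a uniform random number per row via rand(seed) and keeps the row when the draw falls below the fraction registered for that row's stratum, so strata absent from the map are dropped entirely:

    import org.apache.spark.sql.functions.col

    // Assumption: TestSQLContext, the same context the test suites below use.
    val sqlCtx = org.apache.spark.sql.test.TestSQLContext

    // 100 rows spread across strata 0, 1, 2 via id % 3.
    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))

    // Keep ~10% of stratum 0 and ~20% of stratum 1; stratum 2 has no entry,
    // so its fraction defaults to 0.0 and all of its rows are filtered out.
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 0L)

    sampled.groupBy("key").count().orderBy("key").show()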
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 9e61d06f40..2c669bb59a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -226,4 +226,13 @@ public class JavaDataFrameSuite {
     Double result = df.stat().cov("a", "b");
     Assert.assertTrue(Math.abs(result) < 1e-6);
   }
+
+  @Test
+  public void testSampleBy() {
+    DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+    DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
+    Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
+    Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
+    Assert.assertArrayEquals(expected, actual);
+  }
 }
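A note on the Java-friendly overload exercised above: it converts the java.util.Map with asScala and relies on an erasure-based asInstanceOf cast to reinterpret the boxed java.lang.Double values as Scala Doubles before delegating. A hedged sketch of driving that same overload from Scala (the HashMap setup is illustrative; Predef's int2Integer and double2Double conversions box the literals):

    import java.{lang => jl, util => ju}
    import org.apache.spark.sql.functions.col

    val sqlCtx = org.apache.spark.sql.test.TestSQLContext
    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))

    // Boxed keys and values, as a Java caller would supply them.
    val fractions = new ju.HashMap[jl.Integer, jl.Double]()
    fractions.put(0, 0.1)
    fractions.put(1, 0.2)

    // Dispatches to the ju.Map overload, which converts to an immutable
    // Scala Map and delegates to the Scala version shown earlier.
    val sampled = df.stat.sampleBy("key", fractions, 0L)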
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 7ba4ba73e0..07a675e64f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -21,9 +21,9 @@ import java.util.Random
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
 
-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {
 
   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
@@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 5), Row(1, 8)))
+  }
 }
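A closing note on the exact expected counts (5 and 8): rand(seed) is a pseudorandom column, so a fixed seed reproduces the same per-row draws and the filter admits the same rows on every run, which is why both tests can assert exact results. A rough standalone sketch of the per-row decision with the DataFrame machinery stripped away (keepRow and the scala.util.Random stand-in are illustrative, not the actual rand expression, so the counts it prints differ from Spark's):

    import scala.util.Random

    // Illustrative stand-in for sampleBy's UDF: keep a row when a uniform
    // draw falls below its stratum's fraction (0.0 for unlisted strata).
    def keepRow[T](stratum: T, draw: Double, fractions: Map[T, Double]): Boolean =
      draw < fractions.getOrElse(stratum, 0.0)

    val fractions = Map(0L -> 0.1, 1L -> 0.2)
    val rng = new Random(0L) // fixed seed => identical draws on every run
    val kept = (0L until 100L).map(_ % 3).filter(s => keepRow(s, rng.nextDouble(), fractions))
    println(kept.groupBy(identity).mapValues(_.size)) // deterministic, reproducible counts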