From df32669514afc0223ecdeca30fbfbe0b40baef3a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 17:16:03 -0700 Subject: [SPARK-7157][SQL] add sampleBy to DataFrame This was previously committed but then reverted due to test failures (see #6769). Author: Xiangrui Meng Closes #7755 from rxin/SPARK-7157 and squashes the following commits: fbf9044 [Xiangrui Meng] fix python test 542bd37 [Xiangrui Meng] update test 604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 f051afd [Xiangrui Meng] use udf instead of building expression f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 103beb3 [Xiangrui Meng] add Java-friendly sampleBy 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame --- .../apache/spark/sql/DataFrameStatFunctions.scala | 42 ++++++++++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 9 +++++ .../org/apache/spark/sql/DataFrameStatSuite.scala | 12 +++++-- 3 files changed, 61 insertions(+), 2 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4ec58082e7..2e68e358f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9e61d06f40..2c669bb59a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -226,4 +226,13 @@ public class JavaDataFrameSuite { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0..07a675e64f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } -- cgit v1.2.3