From cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Sat, 27 Aug 2016 08:42:41 +0100 Subject: [SPARK-15382][SQL] Fix a bug in sampling with replacement ## What changes were proposed in this pull request? This pr to fix a bug below in sampling with replacement ``` val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show +---+ | c| +---+ | 0| | 1| | 1| | 1| | 2| +---+ ``` ## How was this patch tested? Added a test in `DataFrameSuite`. Author: Takeshi YAMAMURO Closes #14800 from maropu/FixSampleBug. --- .../org/apache/spark/sql/execution/basicPhysicalOperators.scala | 1 + sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+) (limited to 'sql/core/src') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 3562083b06..dd78a78491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -266,6 +266,7 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") + ctx.copyResult = true ctx.addMutableState(s"$samplerClass", sampler, s"$initSampler();") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index cd485770d2..ce0b92a461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(rdd, StructType(schemas), false) assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } + + test("copy results for sampling with replacement") { + val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") + val sampleDf = df.sample(true, 2.00) + val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect + assert(d.size == d.distinct.size) + } } -- cgit v1.2.3