diff options
author | Takeshi YAMAMURO <linguin.m.s@gmail.com> | 2016-08-27 08:42:41 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-27 08:42:41 +0100 |
commit | cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9 (patch) | |
tree | 6cbe696a14d6bf8aadf9e2ebcd75a089fd3e998c /sql/core | |
parent | 718b6bad2d698b76be6906d51da13626e9f3890e (diff) | |
download | spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.tar.gz spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.tar.bz2 spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.zip |
[SPARK-15382][SQL] Fix a bug in sampling with replacement
## What changes were proposed in this pull request?
This pr to fix a bug below in sampling with replacement
```
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show
+---+
| c|
+---+
| 0|
| 1|
| 1|
| 1|
| 2|
+---+
```
## How was this patch tested?
Added a test in `DataFrameSuite`.
Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Closes #14800 from maropu/FixSampleBug.
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 1 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 |
2 files changed, 8 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 3562083b06..dd78a78491 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -266,6 +266,7 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") + ctx.copyResult = true ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler, s"$initSampler();") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index cd485770d2..ce0b92a461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(rdd, StructType(schemas), false) assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) } + + test("copy results for sampling with replacement") { + val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") + val sampleDf = df.sample(true, 2.00) + val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect + assert(d.size == d.distinct.size) + } } |