about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author: Takeshi YAMAMURO <linguin.m.s@gmail.com> 2016-08-27 08:42:41 +0100
committer: Sean Owen <sowen@cloudera.com> 2016-08-27 08:42:41 +0100
commit: cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9 (patch)
tree: 6cbe696a14d6bf8aadf9e2ebcd75a089fd3e998c /sql
parent: 718b6bad2d698b76be6906d51da13626e9f3890e (diff)
download: spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.tar.gz
spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.tar.bz2
spark-cd0ed31ea9965563a9b1ea3e8bfbeaf8347cacd9.zip
[SPARK-15382][SQL] Fix a bug in sampling with replacement
## What changes were proposed in this pull request?

This PR fixes the bug below in sampling with replacement, where `monotonically_increasing_id` yields duplicate values for the sampled rows:

```
val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
df.sample(true, 2.0).withColumn("c", monotonically_increasing_id).select($"c").show
+---+
|  c|
+---+
|  0|
|  1|
|  1|
|  1|
|  2|
+---+
```

## How was this patch tested?

Added a test in `DataFrameSuite`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #14800 from maropu/FixSampleBug.
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 1
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7
2 files changed, 8 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3562083b06..dd78a78491 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,7 @@ case class SampleExec(
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
+ ctx.copyResult = true
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
s"$initSampler();")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cd485770d2..ce0b92a461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.createDataFrame(rdd, StructType(schemas), false)
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}
+
+ test("copy results for sampling with replacement") {
+ val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+ val sampleDf = df.sample(true, 2.00)
+ val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+ assert(d.size == d.distinct.size)
+ }
}