aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala7
2 files changed, 8 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 3562083b06..dd78a78491 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -266,6 +266,7 @@ case class SampleExec(
if (withReplacement) {
val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
val initSampler = ctx.freshName("initSampler")
+ ctx.copyResult = true
ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
s"$initSampler();")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cd485770d2..ce0b92a461 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1579,4 +1579,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = spark.createDataFrame(rdd, StructType(schemas), false)
assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100)
}
+
+ test("copy results for sampling with replacement") {
+ val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b")
+ val sampleDf = df.sample(true, 2.00)
+ val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
+ assert(d.size == d.distinct.size)
+ }
}