author     Sameer Agarwal <sameer@databricks.com>    2016-01-07 10:37:15 -0800
committer  Reynold Xin <rxin@databricks.com>         2016-01-07 10:37:15 -0800
commit     f194d9911a93fc3a78be820096d4836f22d09976 (patch)
tree       81d648d0fe180ef4aa657d889529e48aae422a01
parent     592f64985d0d58b4f6a0366bf975e04ca496bdbe (diff)
[SPARK-12662][SQL] Fix DataFrame.randomSplit to avoid creating overlapping splits
https://issues.apache.org/jira/browse/SPARK-12662

cc yhuai

Author: Sameer Agarwal <sameer@databricks.com>

Closes #10626 from sameeragarwal/randomsplit.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala           7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala  22
2 files changed, 28 insertions(+), 1 deletion(-)
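For context, a hedged sketch of the user-visible symptom this patch addresses; df here stands for any hypothetical DataFrame whose partitions may be re-evaluated in a different row order between actions:

val Array(train, test) = df.randomSplit(Array(0.8, 0.2), seed = 1)
// Before this fix, each split could sample over a different row ordering,
// so the two splits could share rows:
val overlap = train.intersect(test).count()  // could be > 0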
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 7cf2818590..60d2f05b86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1062,10 +1062,15 @@ class DataFrame private[sql](
* @since 1.4.0
*/
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+ // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+ // constituent partitions each time a split is materialized which could result in
+ // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+ // ordering deterministic.
+ val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
+ new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
}.toArray
}
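To make the arithmetic above concrete, here is a minimal, self-contained sketch (plain Scala, assuming the weights Array(2.0, 3.0) used in the test below) of how the normalized cumulative weights carve [0, 1) into disjoint sampling ranges:

val weights = Array(2.0, 3.0)
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
// normalizedCumWeights == Array(0.0, 0.4, 1.0)
val ranges = normalizedCumWeights.sliding(2).map(x => (x(0), x(1))).toArray
// ranges == Array((0.0, 0.4), (0.4, 1.0)): each row's uniform draw lands in
// exactly one range, so the splits are disjoint only if every
// materialization assigns the same draw to the same row, which is what the
// per-partition Sort guarantees.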
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index b15af42caa..63ad6c439a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
}
+ test("randomSplit on reordered partitions") {
+ // This test ensures that randomSplit does not create overlapping splits even when the
+ // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+ // rows in each partition.
+ val data =
+ sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
+ val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+
+ assert(splits.length == 2, "wrong number of splits")
+
+ // Verify that the splits span the entire dataset
+ assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+
+ // Verify that the splits don't overlap
+ assert(splits(0).intersect(splits(1)).collect().isEmpty)
+
+ // Verify that the results are deterministic across multiple runs
+ val firstRun = splits.toSeq.map(_.collect().toSeq)
+ val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+ assert(firstRun == secondRun)
+ }
+
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
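The new test above shuffles each partition to expose the order sensitivity. As a standalone illustration of the same mechanism, here is a hedged, Spark-free simulation (the sample helper below is illustrative, not a Spark internal): a Bernoulli-style sampler consumes one uniform draw per row in iteration order, so a fixed seed only yields disjoint splits when that order is deterministic.

import scala.util.Random

// Keep a row if its uniform draw falls in [lb, ub); one draw per row,
// consumed in iteration order.
def sample(rows: Seq[Int], lb: Double, ub: Double, seed: Long): Set[Int] = {
  val rng = new Random(seed)
  rows.filter { _ => val u = rng.nextDouble(); lb <= u && u < ub }.toSet
}

val ordered = (1 to 10).toVector
val reordered = new Random(42).shuffle(ordered)

// Two splits materialized over *different* orderings of the same rows:
val lo = sample(ordered, 0.0, 0.4, seed = 1)
val hi = sample(reordered, 0.4, 1.0, seed = 1)
// lo and hi may overlap, because a given row receives a different draw in
// each ordering.

// Sorting first, as the patch does per partition, restores determinism:
val loSorted = sample(ordered.sorted, 0.0, 0.4, seed = 1)
val hiSorted = sample(reordered.sorted, 0.4, 1.0, seed = 1)
assert((loSorted intersect hiSorted).isEmpty)  // always disjoint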