aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
author: Burak Yavuz <brkyvz@gmail.com> 2015-04-29 15:34:05 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-04-29 15:34:05 -0700
commit: d7dbce8f7da8a7fd01df6633a6043f51161b7d18 (patch)
tree: 6e57f1d7614527f4796071a541171cd07f8d98d2 /core
parent: c9d530e2e5123dbd4fd13fc487c890d6076b24bf (diff)
download: spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.gz
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.bz2
spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.zip
[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992's PR #5711 using Logical plans.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5761 from brkyvz/random-sample and squashes the following commits:

a1fb0aa [Burak Yavuz] remove unrelated file
69669c3 [Burak Yavuz] fix broken test
1ddb3da [Burak Yavuz] copy base
6000328 [Burak Yavuz] added python api and fixed test
3c11d1b [Burak Yavuz] fixed broken test
f400ade [Burak Yavuz] fix build errors
2384266 [Burak Yavuz] addressed comments v0.1
e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/org/apache/spark/rdd/RDD.scala 19
-rw-r--r-- core/src/test/java/org/apache/spark/JavaAPISuite.java 8
2 files changed, 21 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d80d94a588..330255f892 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -407,12 +407,27 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](
- this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
+ randomSampleWithRange(x(0), x(1), seed)
}.toArray
}
/**
+ * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
+ * range. Sampling only drops elements, so the partitioning of this RDD is preserved.
+ * @param lb lower bound to use for the Bernoulli sampler
+ * @param ub upper bound to use for the Bernoulli sampler
+ * @param seed base seed; the sampler for partition `index` is seeded with `seed + index`
+ * @return A random sub-sample of the RDD without replacement.
+ */
+ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
+ this.mapPartitionsWithIndex({ (index, partition) =>
+ val sampler = new BernoulliCellSampler[T](lb, ub)
+ sampler.setSeed(seed + index)
+ sampler.sample(partition)
+ }, preservesPartitioning = true)
+ }
+
+ /**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 34ac9361d4..c2089b0e56 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -157,11 +157,11 @@ public class JavaAPISuite implements Serializable {
public void randomSplit() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
+ JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
Assert.assertEquals(3, splits.length);
- Assert.assertEquals(2, splits[0].count());
- Assert.assertEquals(3, splits[1].count());
- Assert.assertEquals(5, splits[2].count());
+ Assert.assertEquals(1, splits[0].count());
+ Assert.assertEquals(2, splits[1].count());
+ Assert.assertEquals(7, splits[2].count());
}
@Test