diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2015-04-29 15:34:05 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-04-29 15:34:05 -0700 |
commit | d7dbce8f7da8a7fd01df6633a6043f51161b7d18 (patch) | |
tree | 6e57f1d7614527f4796071a541171cd07f8d98d2 /core/src | |
parent | c9d530e2e5123dbd4fd13fc487c890d6076b24bf (diff) | |
download | spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.gz spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.tar.bz2 spark-d7dbce8f7da8a7fd01df6633a6043f51161b7d18.zip |
[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992's PR #5711 using Logical plans.
Author: Burak Yavuz <brkyvz@gmail.com>
Closes #5761 from brkyvz/random-sample and squashes the following commits:
a1fb0aa [Burak Yavuz] remove unrelated file
69669c3 [Burak Yavuz] fix broken test
1ddb3da [Burak Yavuz] copy base
6000328 [Burak Yavuz] added python api and fixed test
3c11d1b [Burak Yavuz] fixed broken test
f400ade [Burak Yavuz] fix build errors
2384266 [Burak Yavuz] addressed comments v0.1
e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDD.scala | 19 | ||||
-rw-r--r-- | core/src/test/java/org/apache/spark/JavaAPISuite.java | 8 |
2 files changed, 21 insertions, 6 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d80d94a588..330255f892 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -407,12 +407,27 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T]( - this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) + randomSampleWithRange(x(0), x(1), seed) }.toArray } /** + * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability + * range. + * @param lb lower bound to use for the Bernoulli sampler + * @param ub upper bound to use for the Bernoulli sampler + * @param seed the seed for the Random number generator + * @return A random sub-sample of the RDD without replacement. + */ + private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { + this.mapPartitionsWithIndex { case (index, partition) => + val sampler = new BernoulliCellSampler[T](lb, ub) + sampler.setSeed(seed + index) + sampler.sample(partition) + } + } + + /** * Return a fixed-size sampled subset of this RDD in an array * * @param withReplacement whether sampling is done with replacement diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 34ac9361d4..c2089b0e56 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -157,11 +157,11 @@ public class JavaAPISuite implements Serializable { public void randomSplit() { List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD<Integer> rdd = sc.parallelize(ints); - JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11); + JavaRDD<Integer>[] splits = rdd.randomSplit(new 
double[] { 0.4, 0.6, 1.0 }, 31); Assert.assertEquals(3, splits.length); - Assert.assertEquals(2, splits[0].count()); - Assert.assertEquals(3, splits[1].count()); - Assert.assertEquals(5, splits[2].count()); + Assert.assertEquals(1, splits[0].count()); + Assert.assertEquals(2, splits[1].count()); + Assert.assertEquals(7, splits[2].count()); } @Test |