diff options
author | Arun Ramakrishnan <smartnut007@gmail.com> | 2014-04-24 17:27:16 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-04-24 17:27:16 -0700 |
commit | 35e3d199f04fba3230625002a458d43b9578b2e8 (patch) | |
tree | 7e301a1585e3dc45cd1a42b8ce567b0aada57b4f /core | |
parent | f99af8529b6969986f0c3e03f6ff9b7bb9d53ece (diff) | |
download | spark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.gz spark-35e3d199f04fba3230625002a458d43b9578b2e8.tar.bz2 spark-35e3d199f04fba3230625002a458d43b9578b2e8.zip |
SPARK-1438 RDD.sample() make seed param optional
copying form previous pull request https://github.com/apache/spark/pull/462
Its probably better to let the underlying language implementation take care of the default . This was easier to do with python as the default value for seed in random and numpy random is None.
In Scala/Java side it might mean propagating an Option or null(oh no!) down the chain until where the Random is constructed. But, looks like the convention in some other methods was to use System.nanoTime. So, followed that convention.
Conflict with overloaded method in sql.SchemaRDD.sample which also defines default params.
sample(fraction, withReplacement=false, seed=math.random)
Scala does not allow more than one overloaded to have default params. I believe the author intended to override the RDD.sample method and not overload it. So, changed it.
If backward compatible is important, 3 new method can be introduced (without default params) like this
sample(fraction)
sample(fraction, withReplacement)
sample(fraction, withReplacement, seed)
Added some tests for the scala RDD takeSample method.
Author: Arun Ramakrishnan <smartnut007@gmail.com>
This patch had conflicts when merged, resolved by
Committer: Matei Zaharia <matei@databricks.com>
Closes #477 from smartnut007/master and squashes the following commits:
07bb06e [Arun Ramakrishnan] SPARK-1438 fixing more space formatting issues
b9ebfe2 [Arun Ramakrishnan] SPARK-1438 removing redundant import of random in python rddsampler
8d05b1a [Arun Ramakrishnan] SPARK-1438 RDD . Replace System.nanoTime with a Random generated number. python: use a separate instance of Random instead of seeding language api global Random instance.
69619c6 [Arun Ramakrishnan] SPARK-1438 fix spacing issue
0c247db [Arun Ramakrishnan] SPARK-1438 RDD language apis to support optional seed in RDD methods sample/takeSample
Diffstat (limited to 'core')
8 files changed, 61 insertions, 11 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 4330cef396..a6123bd108 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -30,6 +30,7 @@ import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter +import org.apache.spark.util.Utils class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, JavaDoubleRDD] { @@ -133,7 +134,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD = + def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = + sample(withReplacement, fraction, Utils.random.nextLong) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD = fromRDD(srdd.sample(withReplacement, fraction, seed)) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index b3ec270281..554c065358 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) (implicit val kClassTag: ClassTag[K], implicit val vClassTag: ClassTag[V]) @@ -119,7 +120,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = + def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] = + sample(withReplacement, fraction, Utils.random.nextLong) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 327c1552dc..dc698dea75 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -24,6 +24,7 @@ import org.apache.spark._ import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends JavaRDDLike[T, JavaRDD[T]] { @@ -98,7 +99,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = + def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = + sample(withReplacement, fraction, Utils.random.nextLong) + + /** + * Return a sampled subset of this RDD. + */ + def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 725c423a53..574a98636a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -34,6 +34,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This @@ -394,7 +395,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = { + def takeSample(withReplacement: Boolean, num: Int): JList[T] = + takeSample(withReplacement, num, Utils.random.nextLong) + + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq new java.util.ArrayList(arr) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index b4e3bb5d75..b5b8a5706d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.util.random.RandomSampler +import org.apache.spark.util.Utils private[spark] class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) @@ -38,14 +39,14 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) * * @param prev RDD to be sampled * @param sampler a random sampler - * @param seed random seed, default to System.nanoTime + * @param seed random seed * @tparam T input RDD item type * @tparam U sampled RDD item type */ private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], - @transient seed: Long = System.nanoTime) + @transient seed: Long = Utils.random.nextLong) extends RDD[U](prev) { override def getPartitions: Array[Partition] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6c897cc03b..e8bbfbf016 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -341,7 +341,9 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { + def sample(withReplacement: Boolean, + fraction: Double, + seed: Long = Utils.random.nextLong): RDD[T] = { require(fraction >= 0.0, "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) @@ -354,11 +356,11 @@ abstract class RDD[T: ClassTag]( * Randomly splits this RDD with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1 - * @param seed random seed, default to System.nanoTime + * @param seed random seed * * @return split RDDs in an array */ - def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = { + def randomSplit(weights: Array[Double], seed: Long = Utils.random.nextLong): Array[RDD[T]] = { val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => @@ -366,7 +368,8 @@ abstract class RDD[T: ClassTag]( }.toArray } - def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { + def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = + { var fraction = 0.0 var total = 0 val multiplier = 3.0 diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d333e2a88c..084a71c4ca 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -46,6 +46,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, private[spark] object Utils extends Logging { val osName = System.getProperty("os.name") + + val random = new Random() /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 2676558bfc..8da9a0da70 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -463,7 +463,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) - + + for (num <- List(5, 20, 100)) { + val sample = data.takeSample(withReplacement=false, num=num) + assert(sample.size === num) // Got exactly num elements + assert(sample.toSet.size === num) // Elements are distinct + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements @@ -481,6 +487,19 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sample.size === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") } + { + val sample = data.takeSample(withReplacement=true, num=20) + assert(sample.size === 20) // Got exactly 100 elements + assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } + { + val sample = data.takeSample(withReplacement=true, num=100) + assert(sample.size === 100) // Got exactly 100 elements + // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") + assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") + } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=true, 100, seed) assert(sample.size === 100) // Got exactly 100 elements |