diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala) | 18 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala | 32 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala | 34 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala) | 6 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala | 8 |
5 files changed, 52 insertions, 46 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 7ecb409c4a..9cab49f6ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -25,21 +25,21 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** * :: Experimental :: - * Trait for random number generators that generate i.i.d. values from a distribution. + * Trait for random data generators that generate i.i.d. data. */ @Experimental -trait DistributionGenerator extends Pseudorandom with Serializable { +trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** - * Returns an i.i.d. sample as a Double from an underlying distribution. + * Returns an i.i.d. sample as a generic type from an underlying distribution. */ - def nextValue(): Double + def nextValue(): T /** - * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the + * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ - def copy(): DistributionGenerator + def copy(): RandomDataGenerator[T] } /** @@ -47,7 +47,7 @@ trait DistributionGenerator extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @Experimental -class UniformGenerator extends DistributionGenerator { +class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -66,7 +66,7 @@ class UniformGenerator extends DistributionGenerator { * Generates i.i.d. samples from the standard normal distribution. */ @Experimental -class StandardNormalGenerator extends DistributionGenerator { +class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() @@ -87,7 +87,7 @@ class StandardNormalGenerator extends DistributionGenerator { * @param mean mean for the Poisson distribution. */ @Experimental -class PoissonGenerator(val mean: Double) extends DistributionGenerator { +class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala index 021d651d4d..b0a0593223 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala @@ -24,6 +24,8 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag + /** * :: Experimental :: * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. @@ -200,12 +202,12 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, numPartitions: Int, - seed: Long): RDD[Double] = { - new RandomRDD(sc, size, numPartitions, generator, seed) + seed: Long): RDD[T] = { + new RandomRDD[T](sc, size, numPartitions, generator, seed) } /** @@ -219,11 +221,11 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], size: Long, - numPartitions: Int): RDD[Double] = { - randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong) + numPartitions: Int): RDD[T] = { + randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) } /** @@ -237,10 +239,10 @@ object RandomRDDGenerators { * @return RDD[Double] comprised of i.i.d. samples produced by generator. */ @Experimental - def randomRDD(sc: SparkContext, - generator: DistributionGenerator, - size: Long): RDD[Double] = { - randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) + def randomRDD[T: ClassTag](sc: SparkContext, + generator: RandomDataGenerator[T], + size: Long): RDD[T] = { + randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) } // TODO Generate RDD[Vector] from multivariate distributions. @@ -439,7 +441,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int, @@ -461,7 +463,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int, numPartitions: Int): RDD[Vector] = { @@ -482,7 +484,7 @@ object RandomRDDGenerators { */ @Experimental def randomVectorRDD(sc: SparkContext, - generator: DistributionGenerator, + generator: RandomDataGenerator[Double], numRows: Long, numCols: Int): RDD[Vector] = { randomVectorRDD(sc, generator, numRows, numCols, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f13282d07f..c8db3910c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.mllib.linalg.{DenseVector, Vector} -import org.apache.spark.mllib.random.DistributionGenerator +import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.reflect.ClassTag import scala.util.Random -private[mllib] class RandomRDDPartition(override val index: Int, +private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, - val generator: DistributionGenerator, + val generator: RandomDataGenerator[T], val seed: Long) extends Partition { require(size >= 0, "Non-negative partition size required.") } // These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue -private[mllib] class RandomRDD(@transient sc: SparkContext, +private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext, size: Long, numPartitions: Int, - @transient rng: DistributionGenerator, - @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) { + @transient rng: RandomDataGenerator[T], + @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) { require(size > 0, "Positive RDD size required.") require(numPartitions > 0, "Positive number of partitions required") require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue, "Partition size cannot exceed Int.MaxValue") - override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] - RandomRDD.getPointIterator(split) + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[RandomRDDPartition[T]] + RandomRDD.getPointIterator[T](split) } override def getPartitions: Array[Partition] = { @@ -59,7 +60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, size: Long, vectorSize: Int, numPartitions: Int, - @transient rng: DistributionGenerator, + @transient rng: RandomDataGenerator[Double], @transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) { require(size > 0, "Positive RDD size required.") @@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, "Partition size cannot exceed Int.MaxValue") override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = { - val split = splitIn.asInstanceOf[RandomRDDPartition] + val split = splitIn.asInstanceOf[RandomRDDPartition[Double]] RandomRDD.getVectorIterator(split, vectorSize) } @@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext, private[mllib] object RandomRDD { - def getPartitions(size: Long, + def getPartitions[T](size: Long, numPartitions: Int, - rng: DistributionGenerator, + rng: RandomDataGenerator[T], seed: Long): Array[Partition] = { - val partitions = new Array[RandomRDDPartition](numPartitions) + val partitions = new Array[RandomRDDPartition[T]](numPartitions) var i = 0 var start: Long = 0 var end: Long = 0 @@ -101,7 +102,7 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = { + def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(generator.nextValue()).toIterator @@ -109,7 +110,8 @@ private[mllib] object RandomRDD { // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. - def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = { + def getVectorIterator(partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) Array.fill(partition.size)(new DenseVector( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index 974dec4c0b..3df7c128af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -22,9 +22,9 @@ import org.scalatest.FunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class DistributionGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends FunSuite { - def apiChecks(gen: DistributionGenerator) { + def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers gen.setSeed(42L) @@ -53,7 +53,7 @@ class DistributionGeneratorSuite extends FunSuite { assert(array5.equals(array6)) } - def distributionChecks(gen: DistributionGenerator, + def distributionChecks(gen: RandomDataGenerator[Double], mean: Double = 0.0, stddev: Double = 1.0, epsilon: Double = 0.01) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala index 6aa4f803df..96e0bc63b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala @@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri assert(rdd.partitions.size === numPartitions) // check that partition sizes are balanced - val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble) + val partSizes = rdd.partitions.map(p => + p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble) + val partStats = new StatCounter(partSizes) assert(partStats.max - partStats.min <= 1) } @@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L) assert(rdd.partitions.size === numPartitions) val count = rdd.partitions.foldLeft(0L) { (count, part) => - count + part.asInstanceOf[RandomRDDPartition].size + count + part.asInstanceOf[RandomRDDPartition[Double]].size } assert(count === size) @@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri } } -private[random] class MockDistro extends DistributionGenerator { +private[random] class MockDistro extends RandomDataGenerator[Double] { var seed = 0L |