aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorBurak <brkyvz@gmail.com>2014-08-01 22:32:12 -0700
committerXiangrui Meng <meng@databricks.com>2014-08-01 22:32:12 -0700
commitfda475987f3b8b37d563033b0e45706ce433824a (patch)
tree9dcf73fdbb7e7deb47419efbac387c3fc7ca27ff /mllib
parente25ec06171e3ba95920cbfe9df3cd3d990f1a3a3 (diff)
downloadspark-fda475987f3b8b37d563033b0e45706ce433824a.tar.gz
spark-fda475987f3b8b37d563033b0e45706ce433824a.tar.bz2
spark-fda475987f3b8b37d563033b0e45706ce433824a.zip
[SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator. RandomRDD is now of generic type
The RandomRDDGenerators used to only output RDD[Double]. Now RandomRDDGenerators.randomRDD can be used to generate a random RDD[T] via a class that extends RandomDataGenerator, by supplying a type T and overriding the nextValue() function as they wish. Author: Burak <brkyvz@gmail.com> Closes #1732 from brkyvz/SPARK-2801 and squashes the following commits: c94a694 [Burak] [SPARK-2801][MLlib] Missing ClassTags added 22d96fe [Burak] [SPARK-2801][MLlib]: DistributionGenerator renamed to RandomDataGenerator, generic types added for RandomRDD instead of Double
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala)18
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala32
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala34
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala)6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala8
5 files changed, 52 insertions, 46 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 7ecb409c4a..9cab49f6ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/DistributionGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -25,21 +25,21 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
/**
* :: Experimental ::
- * Trait for random number generators that generate i.i.d. values from a distribution.
+ * Trait for random data generators that generate i.i.d. data.
*/
@Experimental
-trait DistributionGenerator extends Pseudorandom with Serializable {
+trait RandomDataGenerator[T] extends Pseudorandom with Serializable {
/**
- * Returns an i.i.d. sample as a Double from an underlying distribution.
+ * Returns an i.i.d. sample as a generic type from an underlying distribution.
*/
- def nextValue(): Double
+ def nextValue(): T
/**
- * Returns a copy of the DistributionGenerator with a new instance of the rng object used in the
+ * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the
* class when applicable for non-locking concurrent usage.
*/
- def copy(): DistributionGenerator
+ def copy(): RandomDataGenerator[T]
}
/**
@@ -47,7 +47,7 @@ trait DistributionGenerator extends Pseudorandom with Serializable {
* Generates i.i.d. samples from U[0.0, 1.0]
*/
@Experimental
-class UniformGenerator extends DistributionGenerator {
+class UniformGenerator extends RandomDataGenerator[Double] {
// XORShiftRandom for better performance. Thread safety isn't necessary here.
private val random = new XORShiftRandom()
@@ -66,7 +66,7 @@ class UniformGenerator extends DistributionGenerator {
* Generates i.i.d. samples from the standard normal distribution.
*/
@Experimental
-class StandardNormalGenerator extends DistributionGenerator {
+class StandardNormalGenerator extends RandomDataGenerator[Double] {
// XORShiftRandom for better performance. Thread safety isn't necessary here.
private val random = new XORShiftRandom()
@@ -87,7 +87,7 @@ class StandardNormalGenerator extends DistributionGenerator {
* @param mean mean for the Poisson distribution.
*/
@Experimental
-class PoissonGenerator(val mean: Double) extends DistributionGenerator {
+class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
private var rng = new Poisson(mean, new DRand)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
index 021d651d4d..b0a0593223 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala
@@ -24,6 +24,8 @@ import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.reflect.ClassTag
+
/**
* :: Experimental ::
* Generator methods for creating RDDs comprised of i.i.d. samples from some distribution.
@@ -200,12 +202,12 @@ object RandomRDDGenerators {
* @return RDD[Double] comprised of i.i.d. samples produced by generator.
*/
@Experimental
- def randomRDD(sc: SparkContext,
- generator: DistributionGenerator,
+ def randomRDD[T: ClassTag](sc: SparkContext,
+ generator: RandomDataGenerator[T],
size: Long,
numPartitions: Int,
- seed: Long): RDD[Double] = {
- new RandomRDD(sc, size, numPartitions, generator, seed)
+ seed: Long): RDD[T] = {
+ new RandomRDD[T](sc, size, numPartitions, generator, seed)
}
/**
@@ -219,11 +221,11 @@ object RandomRDDGenerators {
* @return RDD[Double] comprised of i.i.d. samples produced by generator.
*/
@Experimental
- def randomRDD(sc: SparkContext,
- generator: DistributionGenerator,
+ def randomRDD[T: ClassTag](sc: SparkContext,
+ generator: RandomDataGenerator[T],
size: Long,
- numPartitions: Int): RDD[Double] = {
- randomRDD(sc, generator, size, numPartitions, Utils.random.nextLong)
+ numPartitions: Int): RDD[T] = {
+ randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong)
}
/**
@@ -237,10 +239,10 @@ object RandomRDDGenerators {
* @return RDD[Double] comprised of i.i.d. samples produced by generator.
*/
@Experimental
- def randomRDD(sc: SparkContext,
- generator: DistributionGenerator,
- size: Long): RDD[Double] = {
- randomRDD(sc, generator, size, sc.defaultParallelism, Utils.random.nextLong)
+ def randomRDD[T: ClassTag](sc: SparkContext,
+ generator: RandomDataGenerator[T],
+ size: Long): RDD[T] = {
+ randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong)
}
// TODO Generate RDD[Vector] from multivariate distributions.
@@ -439,7 +441,7 @@ object RandomRDDGenerators {
*/
@Experimental
def randomVectorRDD(sc: SparkContext,
- generator: DistributionGenerator,
+ generator: RandomDataGenerator[Double],
numRows: Long,
numCols: Int,
numPartitions: Int,
@@ -461,7 +463,7 @@ object RandomRDDGenerators {
*/
@Experimental
def randomVectorRDD(sc: SparkContext,
- generator: DistributionGenerator,
+ generator: RandomDataGenerator[Double],
numRows: Long,
numCols: Int,
numPartitions: Int): RDD[Vector] = {
@@ -482,7 +484,7 @@ object RandomRDDGenerators {
*/
@Experimental
def randomVectorRDD(sc: SparkContext,
- generator: DistributionGenerator,
+ generator: RandomDataGenerator[Double],
numRows: Long,
numCols: Int): RDD[Vector] = {
randomVectorRDD(sc, generator, numRows, numCols,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
index f13282d07f..c8db3910c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala
@@ -19,35 +19,36 @@ package org.apache.spark.mllib.rdd
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
-import org.apache.spark.mllib.random.DistributionGenerator
+import org.apache.spark.mllib.random.RandomDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import scala.reflect.ClassTag
import scala.util.Random
-private[mllib] class RandomRDDPartition(override val index: Int,
+private[mllib] class RandomRDDPartition[T](override val index: Int,
val size: Int,
- val generator: DistributionGenerator,
+ val generator: RandomDataGenerator[T],
val seed: Long) extends Partition {
require(size >= 0, "Non-negative partition size required.")
}
// These two classes are necessary since Range objects in Scala cannot have size > Int.MaxValue
-private[mllib] class RandomRDD(@transient sc: SparkContext,
+private[mllib] class RandomRDD[T: ClassTag](@transient sc: SparkContext,
size: Long,
numPartitions: Int,
- @transient rng: DistributionGenerator,
- @transient seed: Long = Utils.random.nextLong) extends RDD[Double](sc, Nil) {
+ @transient rng: RandomDataGenerator[T],
+ @transient seed: Long = Utils.random.nextLong) extends RDD[T](sc, Nil) {
require(size > 0, "Positive RDD size required.")
require(numPartitions > 0, "Positive number of partitions required")
require(math.ceil(size.toDouble / numPartitions) <= Int.MaxValue,
"Partition size cannot exceed Int.MaxValue")
- override def compute(splitIn: Partition, context: TaskContext): Iterator[Double] = {
- val split = splitIn.asInstanceOf[RandomRDDPartition]
- RandomRDD.getPointIterator(split)
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[RandomRDDPartition[T]]
+ RandomRDD.getPointIterator[T](split)
}
override def getPartitions: Array[Partition] = {
@@ -59,7 +60,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
size: Long,
vectorSize: Int,
numPartitions: Int,
- @transient rng: DistributionGenerator,
+ @transient rng: RandomDataGenerator[Double],
@transient seed: Long = Utils.random.nextLong) extends RDD[Vector](sc, Nil) {
require(size > 0, "Positive RDD size required.")
@@ -69,7 +70,7 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
"Partition size cannot exceed Int.MaxValue")
override def compute(splitIn: Partition, context: TaskContext): Iterator[Vector] = {
- val split = splitIn.asInstanceOf[RandomRDDPartition]
+ val split = splitIn.asInstanceOf[RandomRDDPartition[Double]]
RandomRDD.getVectorIterator(split, vectorSize)
}
@@ -80,12 +81,12 @@ private[mllib] class RandomVectorRDD(@transient sc: SparkContext,
private[mllib] object RandomRDD {
- def getPartitions(size: Long,
+ def getPartitions[T](size: Long,
numPartitions: Int,
- rng: DistributionGenerator,
+ rng: RandomDataGenerator[T],
seed: Long): Array[Partition] = {
- val partitions = new Array[RandomRDDPartition](numPartitions)
+ val partitions = new Array[RandomRDDPartition[T]](numPartitions)
var i = 0
var start: Long = 0
var end: Long = 0
@@ -101,7 +102,7 @@ private[mllib] object RandomRDD {
// The RNG has to be reset every time the iterator is requested to guarantee same data
// every time the content of the RDD is examined.
- def getPointIterator(partition: RandomRDDPartition): Iterator[Double] = {
+ def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = {
val generator = partition.generator.copy()
generator.setSeed(partition.seed)
Array.fill(partition.size)(generator.nextValue()).toIterator
@@ -109,7 +110,8 @@ private[mllib] object RandomRDD {
// The RNG has to be reset every time the iterator is requested to guarantee same data
// every time the content of the RDD is examined.
- def getVectorIterator(partition: RandomRDDPartition, vectorSize: Int): Iterator[Vector] = {
+ def getVectorIterator(partition: RandomRDDPartition[Double],
+ vectorSize: Int): Iterator[Vector] = {
val generator = partition.generator.copy()
generator.setSeed(partition.seed)
Array.fill(partition.size)(new DenseVector(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
index 974dec4c0b..3df7c128af 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/DistributionGeneratorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala
@@ -22,9 +22,9 @@ import org.scalatest.FunSuite
import org.apache.spark.util.StatCounter
// TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
-class DistributionGeneratorSuite extends FunSuite {
+class RandomDataGeneratorSuite extends FunSuite {
- def apiChecks(gen: DistributionGenerator) {
+ def apiChecks(gen: RandomDataGenerator[Double]) {
// resetting seed should generate the same sequence of random numbers
gen.setSeed(42L)
@@ -53,7 +53,7 @@ class DistributionGeneratorSuite extends FunSuite {
assert(array5.equals(array6))
}
- def distributionChecks(gen: DistributionGenerator,
+ def distributionChecks(gen: RandomDataGenerator[Double],
mean: Double = 0.0,
stddev: Double = 1.0,
epsilon: Double = 0.01) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
index 6aa4f803df..96e0bc63b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala
@@ -78,7 +78,9 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri
assert(rdd.partitions.size === numPartitions)
// check that partition sizes are balanced
- val partSizes = rdd.partitions.map(p => p.asInstanceOf[RandomRDDPartition].size.toDouble)
+ val partSizes = rdd.partitions.map(p =>
+ p.asInstanceOf[RandomRDDPartition[Double]].size.toDouble)
+
val partStats = new StatCounter(partSizes)
assert(partStats.max - partStats.min <= 1)
}
@@ -89,7 +91,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri
val rdd = new RandomRDD(sc, size, numPartitions, new UniformGenerator, 0L)
assert(rdd.partitions.size === numPartitions)
val count = rdd.partitions.foldLeft(0L) { (count, part) =>
- count + part.asInstanceOf[RandomRDDPartition].size
+ count + part.asInstanceOf[RandomRDDPartition[Double]].size
}
assert(count === size)
@@ -145,7 +147,7 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri
}
}
-private[random] class MockDistro extends DistributionGenerator {
+private[random] class MockDistro extends RandomDataGenerator[Double] {
var seed = 0L