path: root/mllib/src
author    Holden Karau <holden@pigscanfly.ca>    2015-09-21 18:53:28 +0100
committer Sean Owen <sowen@cloudera.com>         2015-09-21 18:53:28 +0100
commit    20a61dbd9b57957fcc5b58ef8935533914172b07 (patch)
tree      6fd0c28146ce7ef52dae0d2991dc736d871ae2ac /mllib/src
parent    01440395176bdbb2662480f03b27851cb860f385 (diff)
[SPARK-10626] [MLLIB] create java friendly method for random rdd
SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people who want to use custom RandomDataGenerators we should provide a Java-friendly method.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD.
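To illustrate how the new overloads are meant to be called from Java, here is a minimal sketch (not part of this patch). The generator class, application name, and variable names are hypothetical; only the RandomDataGenerator contract and the randomJavaRDD overload introduced below are taken from the change itself.

    import java.io.Serializable;
    import java.util.Random;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.random.RandomDataGenerator;
    import org.apache.spark.mllib.random.RandomRDDs;

    public class RandomJavaRDDExample {

      // Hypothetical custom generator: draws doubles uniformly from [0, 10).
      static class ScaledUniformGenerator implements RandomDataGenerator<Double>, Serializable {
        private final Random random = new Random();

        @Override
        public Double nextValue() {
          return 10.0 * random.nextDouble();
        }

        @Override
        public ScaledUniformGenerator copy() {
          return new ScaledUniformGenerator();
        }

        @Override
        public void setSeed(long seed) {
          random.setSeed(seed);
        }
      }

      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "RandomJavaRDDExample");
        // Fully specified overload added by this patch: size, numPartitions, and seed.
        JavaRDD<Double> samples =
            RandomRDDs.randomJavaRDD(jsc, new ScaledUniformGenerator(), 1000L, 4, 42L);
        System.out.println("count = " + samples.count());
        jsc.stop();
      }
    }

The shorter overloads below simply chain to this one, filling in a random seed and a numPartitions of 0 (which resolves to sc.defaultParallelism).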
Diffstat (limited to 'mllib/src')
-rw-r--r--    mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala          | 52
-rw-r--r--    mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java   | 30
2 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
index 4dd5ea214d..f8ff26b579 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD}
import org.apache.spark.rdd.RDD
@@ -381,7 +382,7 @@ object RandomRDDs {
* @param size Size of the RDD.
* @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
* @param seed Random seed (default: a random long integer).
- * @return RDD[Double] comprised of `i.i.d.` samples produced by generator.
+ * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
*/
@DeveloperApi
@Since("1.1.0")
@@ -394,6 +395,55 @@ object RandomRDDs {
new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed)
}
+ /**
+ * :: DeveloperApi ::
+ * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator.
+ *
+ * @param jsc JavaSparkContext used to create the RDD.
+ * @param generator RandomDataGenerator used to populate the RDD.
+ * @param size Size of the RDD.
+ * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ * @param seed Random seed (default: a random long integer).
+ * @return RDD[T] comprised of `i.i.d.` samples produced by generator.
+ */
+ @DeveloperApi
+ @Since("1.6.0")
+ def randomJavaRDD[T](
+ jsc: JavaSparkContext,
+ generator: RandomDataGenerator[T],
+ size: Long,
+ numPartitions: Int,
+ seed: Long): JavaRDD[T] = {
+ implicit val ctag: ClassTag[T] = fakeClassTag
+ val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed)
+ JavaRDD.fromRDD(rdd)
+ }
+
+ /**
+ * [[RandomRDDs#randomJavaRDD]] with the default seed.
+ */
+ @DeveloperApi
+ @Since("1.6.0")
+ def randomJavaRDD[T](
+ jsc: JavaSparkContext,
+ generator: RandomDataGenerator[T],
+ size: Long,
+ numPartitions: Int): JavaRDD[T] = {
+ randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong())
+ }
+
+ /**
+ * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions
+ */
+ @DeveloperApi
+ @Since("1.6.0")
+ def randomJavaRDD[T](
+ jsc: JavaSparkContext,
+ generator: RandomDataGenerator[T],
+ size: Long): JavaRDD[T] = {
+ randomJavaRDD(jsc, generator, size, 0);
+ }
+
// TODO Generate RDD[Vector] from multivariate distributions.
/**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index 33d81b1e95..fce5f6712f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.random;
+import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
@@ -231,4 +232,33 @@ public class JavaRandomRDDsSuite {
}
}
+ @Test
+ public void testArbitrary() {
+ long size = 10;
+ long seed = 1L;
+ int numPartitions = 0;
+ StringGenerator gen = new StringGenerator();
+ JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
+ JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
+ JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
+ for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(size, rdd.count());
+ Assert.assertEquals(2, rdd.first().length());
+ }
+ }
+}
+
+// This is just a test generator, it always returns a string of 42
+class StringGenerator implements RandomDataGenerator<String>, Serializable {
+ @Override
+ public String nextValue() {
+ return "42";
+ }
+ @Override
+ public StringGenerator copy() {
+ return new StringGenerator();
+ }
+ @Override
+ public void setSeed(long seed) {
+ }
}