diff options
author | Meihua Wu <meihuawu@umich.edu> | 2015-09-22 11:05:24 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-09-22 11:05:24 +0100 |
commit | 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 (patch) | |
tree | b66eaaafa6a8b0067ca3f1bc06c360bf174f27a7 /mllib | |
parent | 7278f792a73bbcf8d68f38dc2d87cf722693c4cf (diff) | |
download | spark-870b8a2edd44c9274c43ca0db4ef5b0998e16fd8.tar.gz spark-870b8a2edd44c9274c43ca0db4ef5b0998e16fd8.tar.bz2 spark-870b8a2edd44c9274c43ca0db4ef5b0998e16fd8.zip |
[SPARK-10706] [MLLIB] Add java wrapper for random vector rdd
Add java wrapper for random vector rdd
holdenk srowen
Author: Meihua Wu <meihuawu@umich.edu>
Closes #8841 from rotationsymmetry/SPARK-10706.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala | 42 | ||||
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java | 17 |
2 files changed, 59 insertions, 0 deletions
/**
 * Java-friendly version of [[RandomRDDs#randomVectorRDD]].
 *
 * @param jsc JavaSparkContext used to create the RDD.
 * @param generator RandomDataGenerator used to populate the RDD.
 * @param numRows Number of Vectors in the RDD.
 * @param numCols Number of elements in each Vector.
 * @param numPartitions Number of partitions in the RDD.
 * @param seed Seed for the RNG that generates the seed for the generator in each partition.
 * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int,
    numPartitions: Int,
    seed: Long): JavaRDD[Vector] = {
  // Delegate to the Scala API, then wrap the result for Java callers.
  val scalaRDD = randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed)
  scalaRDD.toJavaRDD()
}

/**
 * [[RandomRDDs#randomJavaVectorRDD]] with the default seed.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int,
    numPartitions: Int): JavaRDD[Vector] = {
  // Omitting the seed lets the underlying Scala API pick one.
  val scalaRDD = randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions)
  scalaRDD.toJavaRDD()
}

/**
 * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed.
 */
@DeveloperApi
@Since("1.6.0")
def randomJavaVectorRDD(
    jsc: JavaSparkContext,
    generator: RandomDataGenerator[Double],
    numRows: Long,
    numCols: Int): JavaRDD[Vector] = {
  // Both partition count and seed fall back to the Scala API defaults.
  val scalaRDD = randomVectorRDD(jsc.sc, generator, numRows, numCols)
  scalaRDD.toJavaRDD()
}
*/ private def numPartitionsOrDefault(sc: SparkContext, numPartitions: Int): Int = { diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f..5728df5aee 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public class JavaRandomRDDsSuite { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 |