aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2014-08-19 16:06:48 -0700
committer Xiangrui Meng <meng@databricks.com> 2014-08-19 16:06:48 -0700
commit 825d4fe47b9c4d48de88622dd48dcf83beb8b80a (patch)
tree d51775e9f88bff51458e57a5ec16de6e0b93b91a /mllib/src/test/java/org
parent d7e80c2597d4a9cae2e0cb35a86f7889323f4cbb (diff)
download spark-825d4fe47b9c4d48de88622dd48dcf83beb8b80a.tar.gz
spark-825d4fe47b9c4d48de88622dd48dcf83beb8b80a.tar.bz2
spark-825d4fe47b9c4d48de88622dd48dcf83beb8b80a.zip
[SPARK-3136][MLLIB] Create Java-friendly methods in RandomRDDs
Though we don't use default arguments for methods in RandomRDDs, it is still not easy for Java users to use because the output type is either `RDD[Double]` or `RDD[Vector]`. Java users should expect `JavaDoubleRDD` and `JavaRDD[Vector]`, respectively. We should create dedicated methods for Java users, and allow default arguments in Scala methods in RandomRDDs, to make life easier for both Java and Scala users. This PR also contains documentation for random data generation. brkyvz

Author: Xiangrui Meng <meng@databricks.com>

Closes #2041 from mengxr/stat-doc and squashes the following commits:

fc5eedf [Xiangrui Meng] add missing comma
ffde810 [Xiangrui Meng] address comments
aef6d07 [Xiangrui Meng] add doc for random data generation
b99d94b [Xiangrui Meng] add java-friendly methods to RandomRDDs
Diffstat (limited to 'mllib/src/test/java/org')
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java 134
1 files changed, 134 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
new file mode 100644
index 0000000000..a725736ca1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.random;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.api.java.JavaRDD;
+import org.junit.Assert;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import static org.apache.spark.mllib.random.RandomRDDs.*;
+
+public class JavaRandomRDDsSuite { // JUnit tests for the *JavaRDD / *JavaVectorRDD methods statically imported from RandomRDDs
+ private transient JavaSparkContext sc; // assigned in setUp(), stopped and cleared in tearDown()
+
+ @Before // runs before each test method
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); // new context per test: master "local", app name = suite name
+ }
+
+ @After // runs after each test method
+ public void tearDown() {
+ sc.stop(); // shut the context down once the test has finished
+ sc = null; // drop the reference to the stopped context
+ }
+
+ @Test
+ public void testUniformRDD() { // checks the element count for all three uniformJavaRDD overloads
+ long m = 1000L; // requested size; compared against count() below
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); // overload: size only
+ JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); // overload: size + p
+ JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); // overload: size + p + seed
+ for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // every overload must yield exactly m elements
+ }
+ }
+
+ @Test
+ public void testNormalRDD() { // checks the element count for all three normalJavaRDD overloads
+ long m = 1000L; // requested size; compared against count() below
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); // overload: size only
+ JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); // overload: size + p
+ JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); // overload: size + p + seed
+ for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // every overload must yield exactly m elements
+ }
+ }
+
+ @Test
+ public void testPoissonRDD() { // checks the element count for all three poissonJavaRDD overloads
+ double mean = 2.0; // forwarded to poissonJavaRDD right after sc
+ long m = 1000L; // requested size; compared against count() below
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); // overload: mean + size
+ JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); // overload: mean + size + p
+ JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); // overload: mean + size + p + seed
+ for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // every overload must yield exactly m elements
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked") // NOTE(review): presumably for the generic varargs call to Lists.newArrayList below — confirm
+ public void testUniformVectorRDD() { // checks row count and vector size for all three uniformJavaVectorRDD overloads
+ long m = 100L; // expected number of vectors; compared against count() below
+ int n = 10; // expected size() of each vector
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n); // overload: m + n
+ JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p); // overload: m + n + p
+ JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); // overload: m + n + p + seed
+ for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // exactly m vectors
+ Assert.assertEquals(n, rdd.first().size()); // only the first vector's size is checked
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked") // NOTE(review): presumably for the generic varargs call to Lists.newArrayList below — confirm
+ public void testNormalVectorRDD() { // checks row count and vector size for all three normalJavaVectorRDD overloads
+ long m = 100L; // expected number of vectors; compared against count() below
+ int n = 10; // expected size() of each vector
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n); // overload: m + n
+ JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p); // overload: m + n + p
+ JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); // overload: m + n + p + seed
+ for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // exactly m vectors
+ Assert.assertEquals(n, rdd.first().size()); // only the first vector's size is checked
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked") // NOTE(review): presumably for the generic varargs call to Lists.newArrayList below — confirm
+ public void testPoissonVectorRDD() { // checks row count and vector size for all three poissonJavaVectorRDD overloads
+ double mean = 2.0; // forwarded to poissonJavaVectorRDD right after sc
+ long m = 100L; // expected number of vectors; compared against count() below
+ int n = 10; // expected size() of each vector
+ int p = 2; // NOTE(review): presumably the number of partitions — confirm against RandomRDDs
+ long seed = 1L; // only used by the last overload
+ JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n); // overload: mean + m + n
+ JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); // overload: mean + m + n + p
+ JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); // overload: mean + m + n + p + seed
+ for (JavaRDD<Vector> rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) {
+ Assert.assertEquals(m, rdd.count()); // exactly m vectors
+ Assert.assertEquals(n, rdd.first().size()); // only the first vector's size is checked
+ }
+ }
+}