aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-13 13:42:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-13 13:42:35 -0700
commit864de8eaf4b6ad5c9099f6f29e251c56b029f631 (patch)
tree02817e0b6860c2899ff366af83757d9dd4d014df /mllib/src/test/java
parent8815ba2f674dbb18eb499216df9942b411e10daa (diff)
downloadspark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.tar.gz
spark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.tar.bz2
spark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.zip
[SPARK-9661] [MLLIB] [ML] Java compatibility
I skimmed through the docs for various instance of Object and replaced them with Java compaible versions of the same. 1. Some methods in LDAModel. 2. runMiniBatchSGD 3. kolmogorovSmirnovTest Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #8126 from MechCoder/java_incop.
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java24
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java22
2 files changed, 46 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index d272a42c85..427be9430d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -124,6 +124,10 @@ public class JavaLDASuite implements Serializable {
}
});
assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
+
+ // Check: javaTopTopicsPerDocuments
+ JavaRDD<scala.Tuple3<java.lang.Long, int[], java.lang.Double[]>> topTopics =
+ model.javaTopTopicsPerDocument(3);
}
@Test
@@ -160,11 +164,31 @@ public class JavaLDASuite implements Serializable {
assertEquals(roundedLocalTopicSummary.length, k);
}
+ @Test
+ public void localLdaMethods() {
+ JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
+ JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
+
+ // check: topicDistributions
+ assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count());
+
+ // check: logPerplexity
+ double logPerplexity = toyModel.logPerplexity(pairedDocs);
+
+ // check: logLikelihood.
+ ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<Tuple2<Long, Vector>>();
+ docsSingleWord.add(new Tuple2<Long, Vector>(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0)));
+ JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
+ double logLikelihood = toyModel.logLikelihood(single);
+ }
+
private static int tinyK = LDASuite$.MODULE$.tinyK();
private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
private static Tuple2<int[], double[]>[] tinyTopicDescription =
LDASuite$.MODULE$.tinyTopicDescription();
private JavaPairRDD<Long, Vector> corpus;
+ private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel();
+ private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite$.MODULE$.javaToyData();
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 62f7f26b7c..eb4e369862 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -27,7 +27,12 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.ChiSqTestResult;
+import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -53,4 +58,21 @@ public class JavaStatisticsSuite implements Serializable {
// Check default method
assertEquals(corr1, corr2);
}
+
+ @Test
+ public void kolmogorovSmirnovTest() {
+ JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0));
+ KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
+ KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
+ data, "norm", 0.0, 1.0);
+ }
+
+ @Test
+ public void chiSqTest() {
+ JavaRDD<LabeledPoint> data = sc.parallelize(Lists.newArrayList(
+ new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
+ new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
+ ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
+ }
}