diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-08-13 13:42:35 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-13 13:42:35 -0700 |
commit | 864de8eaf4b6ad5c9099f6f29e251c56b029f631 (patch) | |
tree | 02817e0b6860c2899ff366af83757d9dd4d014df /mllib/src/test/java/org/apache | |
parent | 8815ba2f674dbb18eb499216df9942b411e10daa (diff) | |
download | spark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.tar.gz spark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.tar.bz2 spark-864de8eaf4b6ad5c9099f6f29e251c56b029f631.zip |
[SPARK-9661] [MLLIB] [ML] Java compatibility
I skimmed through the docs for various instances of Object and replaced them with Java-compatible versions of the same.
1. Some methods in LDAModel.
2. runMiniBatchSGD
3. kolmogorovSmirnovTest
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #8126 from MechCoder/java_incop.
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java | 24 | ||||
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java | 22 |
2 files changed, 46 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index d272a42c85..427be9430d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -124,6 +124,10 @@ public class JavaLDASuite implements Serializable { } }); assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); + + // Check: javaTopTopicsPerDocuments + JavaRDD<scala.Tuple3<java.lang.Long, int[], java.lang.Double[]>> topTopics = + model.javaTopTopicsPerDocument(3); } @Test @@ -160,11 +164,31 @@ public class JavaLDASuite implements Serializable { assertEquals(roundedLocalTopicSummary.length, k); } + @Test + public void localLdaMethods() { + JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2); + JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs); + + // check: topicDistributions + assertEquals(toyModel.topicDistributions(pairedDocs).count(), pairedDocs.count()); + + // check: logPerplexity + double logPerplexity = toyModel.logPerplexity(pairedDocs); + + // check: logLikelihood. 
+ ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<Tuple2<Long, Vector>>(); + docsSingleWord.add(new Tuple2<Long, Vector>(Long.valueOf(0), Vectors.dense(1.0, 0.0, 0.0))); + JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + double logLikelihood = toyModel.logLikelihood(single); + } + private static int tinyK = LDASuite$.MODULE$.tinyK(); private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); private static Tuple2<int[], double[]>[] tinyTopicDescription = LDASuite$.MODULE$.tinyTopicDescription(); private JavaPairRDD<Long, Vector> corpus; + private LocalLDAModel toyModel = LDASuite$.MODULE$.toyModel(); + private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite$.MODULE$.javaToyData(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 62f7f26b7c..eb4e369862 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -27,7 +27,12 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.stat.test.ChiSqTestResult; +import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; public class JavaStatisticsSuite implements Serializable { private transient JavaSparkContext sc; @@ -53,4 +58,21 @@ public class JavaStatisticsSuite implements Serializable { // Check default method assertEquals(corr1, corr2); } + + @Test + public void kolmogorovSmirnovTest() { + JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 
1.0, -1.0, 2.0)); + KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); + KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( + data, "norm", 0.0, 1.0); + } + + @Test + public void chiSqTest() { + JavaRDD<LabeledPoint> data = sc.parallelize(Lists.newArrayList( + new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), + new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), + new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); + ChiSqTestResult[] testResults = Statistics.chiSqTest(data); + } } |