aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-12-08 11:46:26 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-12-08 11:46:26 -0800
commit5cb4695051e3dac847b1ea14d62e54dcf672c31c (patch)
treee75ce10784a244e720049652896a70ca0a99c306 /mllib/src/test/java/org/apache
parent4bcb894948c1b7294d84e2bf58abb1d79e6759c6 (diff)
downloadspark-5cb4695051e3dac847b1ea14d62e54dcf672c31c.tar.gz
spark-5cb4695051e3dac847b1ea14d62e54dcf672c31c.tar.bz2
spark-5cb4695051e3dac847b1ea14d62e54dcf672c31c.zip
[SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs
jira: https://issues.apache.org/jira/browse/SPARK-11605 Check Java compatibility for MLlib for this release. fix: 1. `StreamingTest.registerStream` needs a Java-friendly interface. 2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` have Java compatibility issues. Mark them as `developerAPI`. TBD: [updated] no fix for now per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel` `public scala.Option<java.lang.Object> getThreshold();` has the wrong return type for Java invocation. `SVMModel` has a similar issue. Yet adding a `scala.Option<java.util.Double> getThreshold()` would result in an overloading error due to the same function signature. And adding a new function with a different name seems unnecessary. cc jkbradley feynmanliang Author: Yuhao Yang <hhbyyh@gmail.com> Closes #10102 from hhbyyh/javaAPI.
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java38
1 file changed, 35 insertions, 3 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 4795809e47..66b2ceacb0 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -18,34 +18,49 @@
package org.apache.spark.mllib.stat;
import java.io.Serializable;
-
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
+import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
+ private transient JavaStreamingContext ssc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaStatistics");
+ SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("JavaStatistics")
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+ sc = new JavaSparkContext(conf);
+ ssc = new JavaStreamingContext(sc, new Duration(1000));
+ ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
- sc.stop();
+ ssc.stop();
+ ssc = null;
sc = null;
}
@@ -76,4 +91,21 @@ public class JavaStatisticsSuite implements Serializable {
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
}
+
+ @Test
+ public void streamingTest() {
+ List<BinarySample> trainingBatch = Arrays.asList(
+ new BinarySample(true, 1.0),
+ new BinarySample(false, 2.0));
+ JavaDStream<BinarySample> training =
+ attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
+ int numBatches = 2;
+ StreamingTest model = new StreamingTest()
+ .setWindowSize(0)
+ .setPeacePeriod(0)
+ .setTestMethod("welch");
+ model.registerStream(training);
+ attachTestOutputStream(training);
+ runStreams(ssc, numBatches, numBatches);
+ }
}