author     Yuhao Yang <hhbyyh@gmail.com>                2015-12-08 11:46:26 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-12-08 11:46:26 -0800
commit     5cb4695051e3dac847b1ea14d62e54dcf672c31c (patch)
tree       e75ce10784a244e720049652896a70ca0a99c306 /mllib
parent     4bcb894948c1b7294d84e2bf58abb1d79e6759c6 (diff)
[SPARK-11605][MLLIB] ML 1.6 QA: API: Java compatibility, docs
jira: https://issues.apache.org/jira/browse/SPARK-11605

Check Java compatibility for MLlib for this release.

fix:
1. `StreamingTest.registerStream` needs a Java-friendly interface.
2. `GradientBoostedTreesModel.computeInitialPredictionAndError` and `GradientBoostedTreesModel.updatePredictionError` have a Java compatibility issue. Mark them as `DeveloperApi`.

TBD: [updated] no fix for now, per discussion. `org.apache.spark.mllib.classification.LogisticRegressionModel` exposes `public scala.Option<java.lang.Object> getThreshold();`, which has the wrong return type for Java callers, and `SVMModel` has a similar issue. Adding a `scala.Option<java.lang.Double> getThreshold()` would result in an overloading error, since the two methods would differ only in return type, and adding a new method with a different name seems unnecessary.

cc jkbradley feynmanliang

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #10102 from hhbyyh/javaAPI.
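For reference, a minimal usage sketch (not part of the patch) of the new Java-friendly overload, mirroring the JavaStatisticsSuite test added below. The class name, the local-mode configuration, and the queue-based input stream are illustrative assumptions; only BinarySample, StreamingTest, and registerStream(JavaDStream<BinarySample>) come from this change.

    import java.util.Arrays;
    import java.util.LinkedList;
    import java.util.Queue;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.mllib.stat.test.BinarySample;
    import org.apache.spark.mllib.stat.test.StreamingTest;
    import org.apache.spark.mllib.stat.test.StreamingTestResult;
    import org.apache.spark.streaming.Duration;
    import org.apache.spark.streaming.api.java.JavaDStream;
    import org.apache.spark.streaming.api.java.JavaStreamingContext;

    public class StreamingTestExample {
      public static void main(String[] args) throws InterruptedException {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("StreamingTestExample");
        JavaStreamingContext jssc = new JavaStreamingContext(conf, new Duration(1000));
        // Checkpointing is required because windowSize = 0 uses updateStateByKey internally.
        jssc.checkpoint("checkpoint");

        // Each observation is tagged with its group: true = experiment, false = control.
        Queue<JavaRDD<BinarySample>> batches = new LinkedList<>();
        batches.add(jssc.sparkContext().parallelize(Arrays.asList(
            new BinarySample(true, 1.0),
            new BinarySample(false, 2.0))));
        JavaDStream<BinarySample> observations = jssc.queueStream(batches);

        // The new overload accepts a JavaDStream directly, so no Scala DStream
        // conversion is needed from Java code.
        StreamingTest test = new StreamingTest()
            .setWindowSize(0)
            .setPeacePeriod(0)
            .setTestMethod("welch");
        JavaDStream<StreamingTestResult> results = test.registerStream(observations);
        results.print();

        jssc.start();
        jssc.awaitTerminationOrTimeout(5000);
        jssc.stop();
      }
    }

The key difference from the existing API is that Java callers receive a JavaDStream<StreamingTestResult> directly instead of having to build and convert a Scala DStream[(Boolean, Double)].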
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala        | 50
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala  |  6
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java         | 38
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala        | 25
4 files changed, 94 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
index 75c6a51d09..e990fe0768 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala
@@ -17,13 +17,31 @@
package org.apache.spark.mllib.stat.test
+import scala.beans.BeanInfo
+
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.api.java.JavaDStream
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter
/**
+ * Class that represents the group and value of a sample.
+ *
+ * @param isExperiment if the sample is of the experiment group.
+ * @param value numeric value of the observation.
+ */
+@Since("1.6.0")
+@BeanInfo
+case class BinarySample @Since("1.6.0") (
+ @Since("1.6.0") isExperiment: Boolean,
+ @Since("1.6.0") value: Double) {
+ override def toString: String = {
+ s"($isExperiment, $value)"
+ }
+}
+
+/**
* :: Experimental ::
* Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The
* Boolean identifies which sample each observation comes from, and the Double is the numeric value
@@ -83,13 +101,13 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
/**
* Register a [[DStream]] of values for significance testing.
*
- * @param data stream of (key,value) pairs where the key denotes group membership (true =
- * experiment, false = control) and the value is the numerical metric to test for
- * significance
+ * @param data stream of BinarySample(key,value) pairs where the key denotes group membership
+ * (true = experiment, false = control) and the value is the numerical metric to
+ * test for significance
* @return stream of significance testing results
*/
@Since("1.6.0")
- def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = {
+ def registerStream(data: DStream[BinarySample]): DStream[StreamingTestResult] = {
val dataAfterPeacePeriod = dropPeacePeriod(data)
val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod)
val pairedSummaries = pairSummaries(summarizedData)
@@ -97,9 +115,22 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
testMethod.doTest(pairedSummaries)
}
+ /**
+ * Register a [[JavaDStream]] of values for significance testing.
+ *
+ * @param data stream of BinarySample(isExperiment, value) pairs where isExperiment denotes the
+ * group (true = experiment, false = control) and value is the numerical metric
+ * to test for significance
+ * @return stream of significance testing results
+ */
+ @Since("1.6.0")
+ def registerStream(data: JavaDStream[BinarySample]): JavaDStream[StreamingTestResult] = {
+ JavaDStream.fromDStream(registerStream(data.dstream))
+ }
+
/** Drop all batches inside the peace period. */
private[stat] def dropPeacePeriod(
- data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = {
+ data: DStream[BinarySample]): DStream[BinarySample] = {
data.transform { (rdd, time) =>
if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) {
rdd
@@ -111,9 +142,10 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
/** Compute summary statistics over each key and the specified test window size. */
private[stat] def summarizeByKeyAndWindow(
- data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = {
+ data: DStream[BinarySample]): DStream[(Boolean, StatCounter)] = {
+ val categoryValuePair = data.map(sample => (sample.isExperiment, sample.value))
if (this.windowSize == 0) {
- data.updateStateByKey[StatCounter](
+ categoryValuePair.updateStateByKey[StatCounter](
(newValues: Seq[Double], oldSummary: Option[StatCounter]) => {
val newSummary = oldSummary.getOrElse(new StatCounter())
newSummary.merge(newValues)
@@ -121,7 +153,7 @@ class StreamingTest @Since("1.6.0") () extends Logging with Serializable {
})
} else {
val windowDuration = data.slideDuration * this.windowSize
- data
+ categoryValuePair
.groupByKeyAndWindow(windowDuration)
.mapValues { values =>
val summary = new StatCounter()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 3f427f0be3..feabcee24f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -25,7 +25,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.annotation.Since
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
@@ -186,6 +186,7 @@ class GradientBoostedTreesModel @Since("1.2.0") (
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
/**
+ * :: DeveloperApi ::
* Compute the initial predictions and errors for a dataset for the first
* iteration of gradient boosting.
* @param data: training data.
@@ -196,6 +197,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
* corresponding to every sample.
*/
@Since("1.4.0")
+ @DeveloperApi
def computeInitialPredictionAndError(
data: RDD[LabeledPoint],
initTreeWeight: Double,
@@ -209,6 +211,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
}
/**
+ * :: DeveloperApi ::
* Update a zipped predictionError RDD
* (as obtained with computeInitialPredictionAndError)
* @param data: training data.
@@ -220,6 +223,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
* corresponding to each sample.
*/
@Since("1.4.0")
+ @DeveloperApi
def updatePredictionError(
data: RDD[LabeledPoint],
predictionAndError: RDD[(Double, Double)],
diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
index 4795809e47..66b2ceacb0 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java
@@ -18,34 +18,49 @@
package org.apache.spark.mllib.stat;
import java.io.Serializable;
-
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import static org.apache.spark.streaming.JavaTestUtils.*;
import static org.junit.Assert.assertEquals;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.stat.test.BinarySample;
import org.apache.spark.mllib.stat.test.ChiSqTestResult;
import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
+import org.apache.spark.mllib.stat.test.StreamingTest;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
public class JavaStatisticsSuite implements Serializable {
private transient JavaSparkContext sc;
+ private transient JavaStreamingContext ssc;
@Before
public void setUp() {
- sc = new JavaSparkContext("local", "JavaStatistics");
+ SparkConf conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("JavaStatistics")
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
+ sc = new JavaSparkContext(conf);
+ ssc = new JavaStreamingContext(sc, new Duration(1000));
+ ssc.checkpoint("checkpoint");
}
@After
public void tearDown() {
- sc.stop();
+ ssc.stop();
+ ssc = null;
sc = null;
}
@@ -76,4 +91,21 @@ public class JavaStatisticsSuite implements Serializable {
new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
}
+
+ @Test
+ public void streamingTest() {
+ List<BinarySample> trainingBatch = Arrays.asList(
+ new BinarySample(true, 1.0),
+ new BinarySample(false, 2.0));
+ JavaDStream<BinarySample> training =
+ attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
+ int numBatches = 2;
+ StreamingTest model = new StreamingTest()
+ .setWindowSize(0)
+ .setPeacePeriod(0)
+ .setTestMethod("welch");
+ model.registerStream(training);
+ attachTestOutputStream(training);
+ runStreams(ssc, numBatches, numBatches);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
index d3e9ef4ff0..3c657c8cfe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.mllib.stat
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest}
+import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest,
+ WelchTTest, BinarySample}
import org.apache.spark.streaming.TestSuiteBase
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.util.StatCounter
@@ -48,7 +49,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -75,7 +76,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -102,7 +103,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
@@ -130,7 +131,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(res =>
@@ -157,7 +158,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
input,
- (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream))
+ (inputDStream: DStream[BinarySample]) => model.summarizeByKeyAndWindow(inputDStream))
val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches)
val outputCounts = outputBatches.flatten.map(_._2.count)
@@ -190,7 +191,7 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.dropPeacePeriod(inputDStream))
val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches)
assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch)
@@ -210,11 +211,11 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
.setPeacePeriod(0)
val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42)
- .map(batch => batch.filter(_._1)) // only keep one test group
+ .map(batch => batch.filter(_.isExperiment)) // only keep one test group
// setup and run the model
val ssc = setupStreams(
- input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream))
+ input, (inputDStream: DStream[BinarySample]) => model.registerStream(inputDStream))
val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches)
assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001))
@@ -228,13 +229,13 @@ class StreamingTestSuite extends SparkFunSuite with TestSuiteBase {
stdevA: Double,
meanB: Double,
stdevB: Double,
- seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = {
+ seed: Int): (IndexedSeq[IndexedSeq[BinarySample]]) = {
val rand = new XORShiftRandom(seed)
val numTrues = pointsPerBatch / 2
val data = (0 until numBatches).map { i =>
- (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++
+ (0 until numTrues).map { idx => BinarySample(true, meanA + stdevA * rand.nextGaussian())} ++
(pointsPerBatch / 2 until pointsPerBatch).map { idx =>
- (false, meanB + stdevB * rand.nextGaussian())
+ BinarySample(false, meanB + stdevB * rand.nextGaussian())
}
}