aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authormartinzapletal <zapletal-martin@email.cz>2015-01-31 00:46:02 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-31 00:46:02 -0800
commit34250a613cee39b03f05f2d54ae37029abd8a502 (patch)
tree0b60c4e68c54d0642a7efc7f0c454e0a9f5d9c31 /mllib
parent636408311deeebd77fb83d2249e0afad1a1ba149 (diff)
downloadspark-34250a613cee39b03f05f2d54ae37029abd8a502.tar.gz
spark-34250a613cee39b03f05f2d54ae37029abd8a502.tar.bz2
spark-34250a613cee39b03f05f2d54ae37029abd8a502.zip
[MLLIB][SPARK-3278] Monotone (Isotonic) regression using parallel pool adjacent violators algorithm
This PR introduces an API for Isotonic regression and one algorithm implementing it, Pool adjacent violators. The Isotonic regression problem is sufficiently described in [Floudas, Pardalos, Encyclopedia of Optimization](http://books.google.co.uk/books?id=gtoTkL7heS0C&pg=RA2-PA87&lpg=RA2-PA87&dq=pooled+adjacent+violators+code&source=bl&ots=ZzQbZXVJnn&sig=reH_hBV6yIb9BeZNTF9092vD8PY&hl=en&sa=X&ei=WmF2VLiOIZLO7Qa-t4Bo&ved=0CD8Q6AEwBA#v=onepage&q&f=false), [Wikipedia](http://en.wikipedia.org/wiki/Isotonic_regression) or [Stat Wiki](http://stat.wikia.com/wiki/Isotonic_regression). Pool adjacent violators was introduced by M. Ayer et al. in 1955. A history and development of isotonic regression algorithms is in [Leeuw, Hornik, Mair, Isotone Optimization in R: Pool-Adjacent-Violators Algorithm (PAVA) and Active Set Methods](http://www.jstatsoft.org/v32/i05/paper) and list of available algorithms including their complexity is listed in [Stout, Fastest Isotonic Regression Algorithms](http://web.eecs.umich.edu/~qstout/IsoRegAlg_140812.pdf). An approach to parallelize the computation of PAV was presented in [Kearsley, Tapia, Trosset, An Approach to Parallelizing Isotonic Regression](http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf). The implemented Pool adjacent violators algorithm is based on [Floudas, Pardalos, Encyclopedia of Optimization](http://books.google.co.uk/books?id=gtoTkL7heS0C&pg=RA2-PA87&lpg=RA2-PA87&dq=pooled+adjacent+violators+code&source=bl&ots=ZzQbZXVJnn&sig=reH_hBV6yIb9BeZNTF9092vD8PY&hl=en&sa=X&ei=WmF2VLiOIZLO7Qa-t4Bo&ved=0CD8Q6AEwBA#v=onepage&q&f=false) (Chapter Isotonic regression problems, p. 86) and [Leeuw, Hornik, Mair, Isotone Optimization in R: Pool-Adjacent-Violators Algorithm (PAVA) and Active Set Methods](http://www.jstatsoft.org/v32/i05/paper), also nicely formulated in [Tibshirani, Hoefling, Tibshirani, Nearly-Isotonic Regression](http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf). Implementation itself inspired by R implementations [Klaus, Strimmer, 2008, fdrtool: Estimation of (Local) False Discovery Rates and Higher Criticism](http://cran.r-project.org/web/packages/fdrtool/index.html) and [R Development Core Team, stats, 2009](https://github.com/lgautier/R-3-0-branch-alt/blob/master/src/library/stats/R/isoreg.R). I ran tests with both these libraries and confirmed they yield the same results. More R implementations referenced in aforementioned [Leeuw, Hornik, Mair, Isotone Optimization in R: Pool-Adjacent-Violators Algorithm (PAVA) and Active Set Methods](http://www.jstatsoft.org/v32/i05/paper). The implementation is also inspired and cross checked with other implementations: [Ted Harding, 2007](https://stat.ethz.ch/pipermail/r-help/2007-March/127981.html), [scikit-learn](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/_isotonic.pyx), [Andrew Tulloch, 2014, Julia](https://github.com/ajtulloch/Isotonic.jl/blob/master/src/pooled_pava.jl), [Andrew Tulloch, 2014, c++](https://gist.github.com/ajtulloch/9499872), described in [Andrew Tulloch, Speeding up isotonic regression in scikit-learn by 5,000x](http://tullo.ch/articles/speeding-up-isotonic-regression/), [Fabian Pedregosa, 2012](https://gist.github.com/fabianp/3081831), [Sreangsu Acharyya. libpav](https://bitbucket.org/sreangsu/libpav/src/f744bc1b0fea257f0cacaead1c922eab201ba91b/src/pav.h?at=default) and [Gustav Larsson](https://gist.github.com/gustavla/9499068). Author: martinzapletal <zapletal-martin@email.cz> Author: Xiangrui Meng <meng@databricks.com> Author: Martin Zapletal <zapletal-martin@email.cz> Closes #3519 from zapletal-martin/SPARK-3278 and squashes the following commits: 5a54ea4 [Martin Zapletal] Merge pull request #2 from mengxr/isotonic-fix-java 37ba24e [Xiangrui Meng] fix java tests e3c0e44 [martinzapletal] Merge remote-tracking branch 'origin/SPARK-3278' into SPARK-3278 d8feb82 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 ded071c [Martin Zapletal] Merge pull request #1 from mengxr/SPARK-3278 4dfe136 [Xiangrui Meng] add cache back 0b35c15 [Xiangrui Meng] compress pools and update tests 35d044e [Xiangrui Meng] update paraPAVA 077606b [Xiangrui Meng] minor 05422a8 [Xiangrui Meng] add unit test for model construction 5925113 [Xiangrui Meng] Merge remote-tracking branch 'zapletal-martin/SPARK-3278' into SPARK-3278 80c6681 [Xiangrui Meng] update IRModel 3da56e5 [martinzapletal] SPARK-3278 fixed indentation error 75eac55 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 88eb4e2 [martinzapletal] SPARK-3278 changes after PR comments https://github.com/apache/spark/pull/3519. Isotonic parameter removed from algorithm, defined behaviour for multiple data points with the same feature value, added tests to verify it e60a34f [martinzapletal] SPARK-3278 changes after PR comments https://github.com/apache/spark/pull/3519. Styling and comment fixes. d93c8f9 [martinzapletal] SPARK-3278 changes after PR comments https://github.com/apache/spark/pull/3519. Change to IsotonicRegression api. Isotonic parameter now follows api of other mllib algorithms 1fff77d [martinzapletal] SPARK-3278 changes after PR comments https://github.com/apache/spark/pull/3519. Java api changes, test refactoring, comments and citations, isotonic regression model validations, linear interpolation for predictions 12151e6 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 7aca4cc [martinzapletal] SPARK-3278 comment spelling 9ae9d53 [martinzapletal] SPARK-3278 changes after PR feedback https://github.com/apache/spark/pull/3519. Binary search used for isotonic regression model predictions fad4bf9 [martinzapletal] SPARK-3278 changes after PR comments https://github.com/apache/spark/pull/3519 ce0e30c [martinzapletal] SPARK-3278 readability refactoring f90c8c7 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 0d14bd3 [martinzapletal] SPARK-3278 changed Java api to match Scala api's (Double, Double, Double) 3c2954b [martinzapletal] SPARK-3278 Isotonic regression java api 45aa7e8 [martinzapletal] SPARK-3278 Isotonic regression java api e9b3323 [martinzapletal] Merge branch 'SPARK-3278-weightedLabeledPoint' into SPARK-3278 823d803 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 941fd1f [martinzapletal] SPARK-3278 Isotonic regression java api a24e29f [martinzapletal] SPARK-3278 refactored weightedlabeledpoint to (double, double, double) and updated api deb0f17 [martinzapletal] SPARK-3278 refactored weightedlabeledpoint to (double, double, double) and updated api 8cefd18 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278-weightedLabeledPoint cab5a46 [martinzapletal] SPARK-3278 PR 3519 refactoring WeightedLabeledPoint to tuple as per comments b8b1620 [martinzapletal] Removed WeightedLabeledPoint. Replaced by tuple of doubles 34760d5 [martinzapletal] Removed WeightedLabeledPoint. Replaced by tuple of doubles 089bf86 [martinzapletal] Removed MonotonicityConstraint, Isotonic and Antitonic constraints. Replced by simple boolean c06f88c [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 6046550 [martinzapletal] SPARK-3278 scalastyle errors resolved 8f5daf9 [martinzapletal] SPARK-3278 added comments and cleaned up api to consistently handle weights 629a1ce [martinzapletal] SPARK-3278 added isotonic regression for weighted data. Added tests for Java api 05d9048 [martinzapletal] SPARK-3278 isotonic regression refactoring and api changes 961aa05 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-3278 3de71d0 [martinzapletal] SPARK-3278 added initial version of Isotonic regression algorithm including proposed API
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala304
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java89
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala241
3 files changed, 634 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
new file mode 100644
index 0000000000..5ed6477bae
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import java.io.Serializable
+import java.lang.{Double => JDouble}
+import java.util.Arrays.binarySearch
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
+import org.apache.spark.rdd.RDD
+
+/**
+ * Regression model for isotonic regression.
+ *
+ * @param boundaries Array of boundaries for which predictions are known.
+ * Boundaries must be sorted in increasing order.
+ * @param predictions Array of predictions associated to the boundaries at the same index.
+ * Results of isotonic regression and therefore monotone.
+ * @param isotonic indicates whether this is isotonic or antitonic.
+ */
+class IsotonicRegressionModel (
+ val boundaries: Array[Double],
+ val predictions: Array[Double],
+ val isotonic: Boolean) extends Serializable {
+
+ private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse
+
+ require(boundaries.length == predictions.length)
+ assertOrdered(boundaries)
+ assertOrdered(predictions)(predictionOrd)
+
+ /** Asserts the input array is monotone with the given ordering. */
+ private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = {
+ var i = 1
+ while (i < xs.length) {
+ require(ord.compare(xs(i - 1), xs(i)) <= 0,
+ s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.")
+ i += 1
+ }
+ }
+
+ /**
+ * Predict labels for provided features.
+ * Using a piecewise linear function.
+ *
+ * @param testData Features to be labeled.
+ * @return Predicted labels.
+ */
+ def predict(testData: RDD[Double]): RDD[Double] = {
+ testData.map(predict)
+ }
+
+ /**
+ * Predict labels for provided features.
+ * Using a piecewise linear function.
+ *
+ * @param testData Features to be labeled.
+ * @return Predicted labels.
+ */
+ def predict(testData: JavaDoubleRDD): JavaDoubleRDD = {
+ JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]]))
+ }
+
+ /**
+ * Predict a single label.
+ * Using a piecewise linear function.
+ *
+ * @param testData Feature to be labeled.
+ * @return Predicted label.
+ * 1) If testData exactly matches a boundary then associated prediction is returned.
+ * In case there are multiple predictions with the same boundary then one of them
+ * is returned. Which one is undefined (same as java.util.Arrays.binarySearch).
+ * 2) If testData is lower or higher than all boundaries then first or last prediction
+ * is returned respectively. In case there are multiple predictions with the same
+ * boundary then the lowest or highest is returned respectively.
+ * 3) If testData falls between two values in boundary array then prediction is treated
+ * as piecewise linear function and interpolated value is returned. In case there are
+ * multiple values with the same boundary then the same rules as in 2) are used.
+ */
+ def predict(testData: Double): Double = {
+
+ def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = {
+ y1 + (y2 - y1) * (x - x1) / (x2 - x1)
+ }
+
+ val foundIndex = binarySearch(boundaries, testData)
+ val insertIndex = -foundIndex - 1
+
+ // Find if the index was lower than all values,
+ // higher than all values, in between two values or exact match.
+ if (insertIndex == 0) {
+ predictions.head
+ } else if (insertIndex == boundaries.length){
+ predictions.last
+ } else if (foundIndex < 0) {
+ linearInterpolation(
+ boundaries(insertIndex - 1),
+ predictions(insertIndex - 1),
+ boundaries(insertIndex),
+ predictions(insertIndex),
+ testData)
+ } else {
+ predictions(foundIndex)
+ }
+ }
+}
+
+/**
+ * Isotonic regression.
+ * Currently implemented using parallelized pool adjacent violators algorithm.
+ * Only univariate (single feature) algorithm supported.
+ *
+ * Sequential PAV implementation based on:
+ * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani.
+ * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61.
+ * Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf
+ *
+ * Sequential PAV parallelization based on:
+ * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset.
+ * "An approach to parallelizing isotonic regression."
+ * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147.
+ * Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf
+ */
+class IsotonicRegression private (private var isotonic: Boolean) extends Serializable {
+
+ /**
+ * Constructs IsotonicRegression instance with default parameter isotonic = true.
+ *
+ * @return New instance of IsotonicRegression.
+ */
+ def this() = this(true)
+
+ /**
+ * Sets the isotonic parameter.
+ *
+ * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence.
+ * @return This instance of IsotonicRegression.
+ */
+ def setIsotonic(isotonic: Boolean): this.type = {
+ this.isotonic = isotonic
+ this
+ }
+
+ /**
+ * Run IsotonicRegression algorithm to obtain isotonic regression model.
+ *
+ * @param input RDD of tuples (label, feature, weight) where label is dependent variable
+ * for which we calculate isotonic regression, feature is independent variable
+ * and weight represents number of measures with default 1.
+ * If multiple labels share the same feature value then they are ordered before
+ * the algorithm is executed.
+ * @return Isotonic regression model.
+ */
+ def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = {
+ val preprocessedInput = if (isotonic) {
+ input
+ } else {
+ input.map(x => (-x._1, x._2, x._3))
+ }
+
+ val pooled = parallelPoolAdjacentViolators(preprocessedInput)
+
+ val predictions = if (isotonic) pooled.map(_._1) else pooled.map(-_._1)
+ val boundaries = pooled.map(_._2)
+
+ new IsotonicRegressionModel(boundaries, predictions, isotonic)
+ }
+
+ /**
+ * Run pool adjacent violators algorithm to obtain isotonic regression model.
+ *
+ * @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable
+ * for which we calculate isotonic regression, feature is independent variable
+ * and weight represents number of measures with default 1.
+ * If multiple labels share the same feature value then they are ordered before
+ * the algorithm is executed.
+ * @return Isotonic regression model.
+ */
+ def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = {
+ run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]])
+ }
+
+ /**
+ * Performs a pool adjacent violators algorithm (PAV).
+ * Uses approach with single processing of data where violators
+ * in previously processed data created by pooling are fixed immediately.
+ * Uses optimization of discovering monotonicity violating sequences (blocks).
+ *
+ * @param input Input data of tuples (label, feature, weight).
+ * @return Result tuples (label, feature, weight) where labels were updated
+ * to form a monotone sequence as per isotonic regression definition.
+ */
+ private def poolAdjacentViolators(
+ input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
+
+ if (input.isEmpty) {
+ return Array.empty
+ }
+
+ // Pools sub array within given bounds assigning weighted average value to all elements.
+ def pool(input: Array[(Double, Double, Double)], start: Int, end: Int): Unit = {
+ val poolSubArray = input.slice(start, end + 1)
+
+ val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum
+ val weight = poolSubArray.map(_._3).sum
+
+ var i = start
+ while (i <= end) {
+ input(i) = (weightedSum / weight, input(i)._2, input(i)._3)
+ i = i + 1
+ }
+ }
+
+ var i = 0
+ while (i < input.length) {
+ var j = i
+
+ // Find monotonicity violating sequence, if any.
+ while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) {
+ j = j + 1
+ }
+
+ // If monotonicity was not violated, move to next data point.
+ if (i == j) {
+ i = i + 1
+ } else {
+ // Otherwise pool the violating sequence
+ // and check if pooling caused monotonicity violation in previously processed points.
+ while (i >= 0 && input(i)._1 > input(i + 1)._1) {
+ pool(input, i, j)
+ i = i - 1
+ }
+
+ i = j
+ }
+ }
+
+ // For points having the same prediction, we only keep two boundary points.
+ val compressed = ArrayBuffer.empty[(Double, Double, Double)]
+
+ var (curLabel, curFeature, curWeight) = input.head
+ var rightBound = curFeature
+ def merge(): Unit = {
+ compressed += ((curLabel, curFeature, curWeight))
+ if (rightBound > curFeature) {
+ compressed += ((curLabel, rightBound, 0.0))
+ }
+ }
+ i = 1
+ while (i < input.length) {
+ val (label, feature, weight) = input(i)
+ if (label == curLabel) {
+ curWeight += weight
+ rightBound = feature
+ } else {
+ merge()
+ curLabel = label
+ curFeature = feature
+ curWeight = weight
+ rightBound = curFeature
+ }
+ i += 1
+ }
+ merge()
+
+ compressed.toArray
+ }
+
+ /**
+ * Performs parallel pool adjacent violators algorithm.
+ * Performs Pool adjacent violators algorithm on each partition and then again on the result.
+ *
+ * @param input Input data of tuples (label, feature, weight).
+ * @return Result tuples (label, feature, weight) where labels were updated
+ * to form a monotone sequence as per isotonic regression definition.
+ */
+ private def parallelPoolAdjacentViolators(
+ input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = {
+ val parallelStepResult = input
+ .sortBy(x => (x._2, x._1))
+ .glom()
+ .flatMap(poolAdjacentViolators)
+ .collect()
+ .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering.
+ poolAdjacentViolators(parallelStepResult)
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
new file mode 100644
index 0000000000..d38fc91ace
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import scala.Tuple3;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaIsotonicRegressionSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
+ List<Tuple3<Double, Double, Double>> input = Lists.newArrayList();
+
+ for (int i = 1; i <= labels.length; i++) {
+ input.add(new Tuple3<Double, Double, Double>(labels[i-1], (double) i, 1d));
+ }
+
+ return input;
+ }
+
+ private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
+ JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
+ sc.parallelize(generateIsotonicInput(labels), 2).cache();
+
+ return new IsotonicRegression().run(trainRDD);
+ }
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void testIsotonicRegressionJavaRDD() {
+ IsotonicRegressionModel model =
+ runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
+
+ Assert.assertArrayEquals(
+ new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14);
+ }
+
+ @Test
+ public void testIsotonicRegressionPredictionsJavaRDD() {
+ IsotonicRegressionModel model =
+ runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
+
+ JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0));
+ List<Double> predictions = model.predict(testRDD).collect();
+
+ Assert.assertTrue(predictions.get(0) == 1d);
+ Assert.assertTrue(predictions.get(1) == 1d);
+ Assert.assertTrue(predictions.get(2) == 10d);
+ Assert.assertTrue(predictions.get(3) == 12d);
+ Assert.assertTrue(predictions.get(4) == 12d);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
new file mode 100644
index 0000000000..7ef4524828
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import org.scalatest.{Matchers, FunSuite}
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
+
+ private def round(d: Double) = {
+ Math.round(d * 100).toDouble / 100
+ }
+
+ private def generateIsotonicInput(labels: Seq[Double]): Seq[(Double, Double, Double)] = {
+ Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, 1d))
+ }
+
+ private def generateIsotonicInput(
+ labels: Seq[Double],
+ weights: Seq[Double]): Seq[(Double, Double, Double)] = {
+ Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, weights(i)))
+ }
+
+ private def runIsotonicRegression(
+ labels: Seq[Double],
+ weights: Seq[Double],
+ isotonic: Boolean): IsotonicRegressionModel = {
+ val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache()
+ new IsotonicRegression().setIsotonic(isotonic).run(trainRDD)
+ }
+
+ private def runIsotonicRegression(
+ labels: Seq[Double],
+ isotonic: Boolean): IsotonicRegressionModel = {
+ runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic)
+ }
+
+ test("increasing isotonic regression") {
+ /*
+ The following result could be re-produced with sklearn.
+
+ > from sklearn.isotonic import IsotonicRegression
+ > x = range(9)
+ > y = [1, 2, 3, 1, 6, 17, 16, 17, 18]
+ > ir = IsotonicRegression(x, y)
+ > print ir.predict(x)
+
+ array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18. ])
+ */
+ val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true)
+
+ assert(Array.tabulate(9)(x => model.predict(x)) === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18))
+
+ assert(model.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8))
+ assert(model.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0))
+ assert(model.isotonic)
+ }
+
+ test("isotonic regression with size 0") {
+ val model = runIsotonicRegression(Seq(), true)
+
+ assert(model.predictions === Array())
+ }
+
+ test("isotonic regression with size 1") {
+ val model = runIsotonicRegression(Seq(1), true)
+
+ assert(model.predictions === Array(1.0))
+ }
+
+ test("isotonic regression strictly increasing sequence") {
+ val model = runIsotonicRegression(Seq(1, 2, 3, 4, 5), true)
+
+ assert(model.predictions === Array(1, 2, 3, 4, 5))
+ }
+
+ test("isotonic regression strictly decreasing sequence") {
+ val model = runIsotonicRegression(Seq(5, 4, 3, 2, 1), true)
+
+ assert(model.boundaries === Array(0, 4))
+ assert(model.predictions === Array(3, 3))
+ }
+
+ test("isotonic regression with last element violating monotonicity") {
+ val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), true)
+
+ assert(model.boundaries === Array(0, 1, 2, 4))
+ assert(model.predictions === Array(1, 2, 3, 3))
+ }
+
+ test("isotonic regression with first element violating monotonicity") {
+ val model = runIsotonicRegression(Seq(4, 2, 3, 4, 5), true)
+
+ assert(model.boundaries === Array(0, 2, 3, 4))
+ assert(model.predictions === Array(3, 3, 4, 5))
+ }
+
+ test("isotonic regression with negative labels") {
+ val model = runIsotonicRegression(Seq(-1, -2, 0, 1, -1), true)
+
+ assert(model.boundaries === Array(0, 1, 2, 4))
+ assert(model.predictions === Array(-1.5, -1.5, 0, 0))
+ }
+
+ test("isotonic regression with unordered input") {
+ val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache()
+
+ val model = new IsotonicRegression().run(trainRDD)
+ assert(model.predictions === Array(1, 2, 3, 4, 5))
+ }
+
+ test("weighted isotonic regression") {
+ val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), Seq(1, 1, 1, 1, 2), true)
+
+ assert(model.boundaries === Array(0, 1, 2, 4))
+ assert(model.predictions === Array(1, 2, 2.75, 2.75))
+ }
+
+ test("weighted isotonic regression with weights lower than 1") {
+ val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true)
+
+ assert(model.boundaries === Array(0, 1, 2, 4))
+ assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2))
+ }
+
+ test("weighted isotonic regression with negative weights") {
+ val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(-1, 1, -3, 1, -5), true)
+
+ assert(model.boundaries === Array(0.0, 1.0, 4.0))
+ assert(model.predictions === Array(1.0, 10.0/6, 10.0/6))
+ }
+
+ test("weighted isotonic regression with zero weights") {
+ val model = runIsotonicRegression(Seq[Double](1, 2, 3, 2, 1), Seq[Double](0, 0, 0, 1, 0), true)
+
+ assert(model.boundaries === Array(0.0, 1.0, 4.0))
+ assert(model.predictions === Array(1, 2, 2))
+ }
+
+ test("isotonic regression prediction") {
+ val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true)
+
+ assert(model.predict(-2) === 1)
+ assert(model.predict(-1) === 1)
+ assert(model.predict(0.5) === 1.5)
+ assert(model.predict(0.75) === 1.75)
+ assert(model.predict(1) === 2)
+ assert(model.predict(2) === 10d/3)
+ assert(model.predict(9) === 10d/3)
+ }
+
+ test("isotonic regression prediction with duplicate features") {
+ val trainRDD = sc.parallelize(
+ Seq[(Double, Double, Double)](
+ (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache()
+ val model = new IsotonicRegression().run(trainRDD)
+
+ assert(model.predict(0) === 1)
+ assert(model.predict(1.5) === 2)
+ assert(model.predict(2.5) === 4.5)
+ assert(model.predict(4) === 6)
+ }
+
+ test("antitonic regression prediction with duplicate features") {
+ val trainRDD = sc.parallelize(
+ Seq[(Double, Double, Double)](
+ (5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache()
+ val model = new IsotonicRegression().setIsotonic(false).run(trainRDD)
+
+ assert(model.predict(0) === 6)
+ assert(model.predict(1.5) === 4.5)
+ assert(model.predict(2.5) === 2)
+ assert(model.predict(4) === 1)
+ }
+
+ test("isotonic regression RDD prediction") {
+ val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true)
+
+ val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache()
+ val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2)
+ assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3))
+ }
+
+ test("antitonic regression prediction") {
+ val model = runIsotonicRegression(Seq(7, 5, 3, 5, 1), false)
+
+ assert(model.predict(-2) === 7)
+ assert(model.predict(-1) === 7)
+ assert(model.predict(0.5) === 6)
+ assert(model.predict(0.75) === 5.5)
+ assert(model.predict(1) === 5)
+ assert(model.predict(2) === 4)
+ assert(model.predict(9) === 1)
+ }
+
+ test("model construction") {
+ val model = new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = true)
+ assert(model.predict(-0.5) === 1.0)
+ assert(model.predict(0.0) === 1.0)
+ assert(model.predict(0.5) ~== 1.5 absTol 1e-14)
+ assert(model.predict(1.0) === 2.0)
+ assert(model.predict(1.5) === 2.0)
+
+ intercept[IllegalArgumentException] {
+ // different array sizes.
+ new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0), isotonic = true)
+ }
+
+ intercept[IllegalArgumentException] {
+ // unordered boundaries
+ new IsotonicRegressionModel(Array(1.0, 0.0), Array(1.0, 2.0), isotonic = true)
+ }
+
+ intercept[IllegalArgumentException] {
+ // unordered predictions (isotonic)
+ new IsotonicRegressionModel(Array(0.0, 1.0), Array(2.0, 1.0), isotonic = true)
+ }
+
+ intercept[IllegalArgumentException] {
+ // unordered predictions (antitonic)
+ new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false)
+ }
+ }
+}