path: root/mllib
author    Xiangrui Meng <meng@databricks.com>    2014-05-15 11:59:59 -0700
committer Patrick Wendell <pwendell@gmail.com>   2014-05-15 11:59:59 -0700
commit d52761d67f42ad4d2ff02d96f0675fb3ab709f38 (patch)
tree   e46abacebabaaa5135e0b64d87f2ee6b408ac512 /mllib
parent 94c9d6f59859ebc77fae112c2c42c64b7a4d7f83 (diff)
[SPARK-1741][MLLIB] add predict(JavaRDD) to RegressionModel, ClassificationModel, and KMeans
`model.predict` returns an RDD of a Scala primitive type (Int/Double), which is recognized as Object in Java. Adding predict(JavaRDD) could make life easier for Java users.

Added tests for KMeans, LinearRegression, and NaiveBayes. Will update examples after https://github.com/apache/spark/pull/653 gets merged.

cc: @srowen

Author: Xiangrui Meng <meng@databricks.com>

Closes #670 from mengxr/predict-javardd and squashes the following commits:

b77ccd8 [Xiangrui Meng] Merge branch 'master' into predict-javardd
43caac9 [Xiangrui Meng] add predict(JavaRDD) to RegressionModel, ClassificationModel, and KMeans
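As a rough illustration of the typing gap described in the commit message, the hypothetical Java sketch below (the class and method names are illustrative and not part of this patch) contrasts the existing Scala-facing predict(RDD[Vector]), whose primitive result type surfaces as Object in Java, with the new JavaRDD overload, which returns boxed values with a usable element type:

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;

// Hypothetical helper, not part of this patch.
public class PredictTypingSketch {
  public static void show(KMeansModel model, JavaRDD<Vector> points) {
    // Before this patch: predict(RDD[Vector]) returns RDD[Int], which the
    // Java compiler sees as RDD<Object>.
    RDD<Object> untyped = model.predict(points.rdd());
    System.out.println(untyped.count());

    // With this patch: the JavaRDD overload returns boxed Integers directly.
    JavaRDD<Integer> clusters = model.predict(points);
    System.out.println(clusters.first());
  }
}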
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala  | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala              |  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala          | 11
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java    | 16
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java            | 14
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java  | 21
6 files changed, 76 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 6332301e30..b7a1d90d24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,9 +17,10 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
-import org.apache.spark.annotation.Experimental
/**
* :: Experimental ::
@@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable {
* @return predicted category from the trained model
*/
def predict(testData: Vector): Double
+
+ /**
+ * Predict values for examples stored in a JavaRDD.
+ * @param testData JavaRDD representing data points to be predicted
+ * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+ */
+ def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+ predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
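The default implementation above relies on an unchecked cast: predict(testData.rdd) yields an RDD[Double], and because RDD is not specialized for primitives its elements are already boxed as java.lang.Double at runtime, so after type erasure the asInstanceOf is safe. A minimal hypothetical Java helper exercising the new overload (the class name is illustrative, and LogisticRegressionWithSGD is just one ClassificationModel implementation picked for the sketch):

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;

// Hypothetical helper, not part of this patch.
public class ClassificationPredictSketch {
  public static JavaRDD<Double> classify(JavaRDD<LabeledPoint> training,
                                         JavaRDD<Vector> vectors) {
    LogisticRegressionModel model =
        LogisticRegressionWithSGD.train(training.rdd(), 100);
    // The new overload returns boxed java.lang.Double values, so no casts
    // are needed on the caller's side.
    return model.predict(vectors);
  }
}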
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index ce14b06241..fba21aefaa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.clustering
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
@@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
}
+ /** Maps given points to their cluster indices. */
+ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+ predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
/**
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
* model on the given data.
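The typed result also composes with the rest of the Java API. For example, since predict is a one-to-one map, a caller can zip each input point with its cluster index. A hypothetical fragment (names are illustrative; it assumes an already-trained KMeansModel):

import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;

// Hypothetical helper, not part of this patch.
public class ClusterAssignmentsSketch {
  public static List<Tuple2<Vector, Integer>> assignments(KMeansModel model,
                                                          JavaRDD<Vector> points) {
    // predict(JavaRDD) maps points one-to-one, so zip keeps each point
    // aligned with its cluster index.
    JavaPairRDD<Vector, Integer> withClusters = points.zip(model.predict(points));
    return withClusters.collect();
  }
}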
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index b27e158b43..64b02f7a6e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,9 +17,10 @@
package org.apache.spark.mllib.regression
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.annotation.Experimental
@Experimental
trait RegressionModel extends Serializable {
@@ -38,4 +39,12 @@ trait RegressionModel extends Serializable {
* @return Double prediction from the trained model
*/
def predict(testData: Vector): Double
+
+ /**
+ * Predict values for examples stored in a JavaRDD.
+ * @param testData JavaRDD representing data points to be predicted
+ * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
+ */
+ def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
+ predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
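For regression, the typed predictions pair up with held-out labels without any casting. The hypothetical helper below (illustrative names; the error is computed on the driver for brevity, which assumes the data fits in memory) sketches that pattern against the RegressionModel interface:

import java.util.List;

import scala.Tuple2;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.RegressionModel;

// Hypothetical helper, not part of this patch.
public class RegressionMseSketch {
  public static double meanSquaredError(RegressionModel model,
                                        JavaRDD<LabeledPoint> data) {
    JavaRDD<Vector> features = data.map(new Function<LabeledPoint, Vector>() {
      @Override
      public Vector call(LabeledPoint p) {
        return p.features();
      }
    });
    JavaRDD<Double> labels = data.map(new Function<LabeledPoint, Double>() {
      @Override
      public Double call(LabeledPoint p) {
        return p.label();
      }
    });
    // Both sides derive from `data` by one-to-one maps, so zip keeps
    // predictions aligned with their labels.
    List<Tuple2<Double, Double>> pairs = model.predict(features).zip(labels).collect();
    double sum = 0.0;
    for (Tuple2<Double, Double> pair : pairs) {
      double err = pair._1() - pair._2();
      sum += err * err;
    }
    return pairs.isEmpty() ? 0.0 : sum / pairs.size();
  }
}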
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index c80b1134ed..743a43a139 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.classification;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
@@ -87,4 +89,18 @@ public class JavaNaiveBayesSuite implements Serializable {
int numAccurate2 = validatePrediction(POINTS, model2);
Assert.assertEquals(POINTS.size(), numAccurate2);
}
+
+ @Test
+ public void testPredictJavaRDD() {
+ JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+ NaiveBayesModel model = NaiveBayes.train(examples.rdd());
+ JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
+ @Override
+ public Vector call(LabeledPoint v) throws Exception {
+ return v.features();
+ }});
+ JavaRDD<Double> predictions = model.predict(vectors);
+ // Should be able to get the first prediction.
+ predictions.first();
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 49a614bd90..0c916ca378 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -88,4 +88,18 @@ public class JavaKMeansSuite implements Serializable {
.run(data.rdd());
assertEquals(expectedCenter, model.clusterCenters()[0]);
}
+
+ @Test
+ public void testPredictJavaRDD() {
+ List<Vector> points = Lists.newArrayList(
+ Vectors.dense(1.0, 2.0, 6.0),
+ Vectors.dense(1.0, 3.0, 0.0),
+ Vectors.dense(1.0, 4.0, 6.0)
+ );
+ JavaRDD<Vector> data = sc.parallelize(points, 2);
+ KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
+ JavaRDD<Integer> predictions = model.predict(data);
+ // Should be able to get the first prediction.
+ predictions.first();
+ }
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
index 7151e55351..6dc6877691 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java
@@ -25,8 +25,10 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.util.LinearDataGenerator;
public class JavaLinearRegressionSuite implements Serializable {
@@ -92,4 +94,23 @@ public class JavaLinearRegressionSuite implements Serializable {
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
+ @Test
+ public void testPredictJavaRDD() {
+ int nPoints = 100;
+ double A = 0.0;
+ double[] weights = {10, 10};
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
+ LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
+ LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
+ JavaRDD<Vector> vectors = testRDD.map(new Function<LabeledPoint, Vector>() {
+ @Override
+ public Vector call(LabeledPoint v) throws Exception {
+ return v.features();
+ }
+ });
+ JavaRDD<Double> predictions = model.predict(vectors);
+ // Should be able to get the first prediction.
+ predictions.first();
+ }
}