From d52761d67f42ad4d2ff02d96f0675fb3ab709f38 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 15 May 2014 11:59:59 -0700 Subject: [SPARK-1741][MLLIB] add predict(JavaRDD) to RegressionModel, ClassificationModel, and KMeans `model.predict` returns a RDD of Scala primitive type (Int/Double), which is recognized as Object in Java. Adding predict(JavaRDD) could make life easier for Java users. Added tests for KMeans, LinearRegression, and NaiveBayes. Will update examples after https://github.com/apache/spark/pull/653 gets merged. cc: @srowen Author: Xiangrui Meng Closes #670 from mengxr/predict-javardd and squashes the following commits: b77ccd8 [Xiangrui Meng] Merge branch 'master' into predict-javardd 43caac9 [Xiangrui Meng] add predict(JavaRDD) to RegressionModel, ClassificationModel, and KMeans --- .../mllib/classification/ClassificationModel.scala | 11 ++++++++++- .../apache/spark/mllib/clustering/KMeansModel.scala | 5 +++++ .../spark/mllib/regression/RegressionModel.scala | 11 ++++++++++- .../mllib/classification/JavaNaiveBayesSuite.java | 16 ++++++++++++++++ .../spark/mllib/clustering/JavaKMeansSuite.java | 14 ++++++++++++++ .../mllib/regression/JavaLinearRegressionSuite.java | 21 +++++++++++++++++++++ 6 files changed, 76 insertions(+), 2 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 6332301e30..b7a1d90d24 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.classification +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD -import org.apache.spark.annotation.Experimental /** * :: Experimental :: @@ -43,4 +44,12 @@ trait ClassificationModel extends Serializable { * @return predicted category from the trained model */ def predict(testData: Vector): Double + + /** + * Predict values for examples stored in a JavaRDD. + * @param testData JavaRDD representing data points to be predicted + * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + */ + def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index ce14b06241..fba21aefaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.clustering +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector @@ -40,6 +41,10 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1) } + /** Maps given points to their cluster indices. */ + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index b27e158b43..64b02f7a6e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.regression +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.annotation.Experimental @Experimental trait RegressionModel extends Serializable { @@ -38,4 +39,12 @@ trait RegressionModel extends Serializable { * @return Double prediction from the trained model */ def predict(testData: Vector): Double + + /** + * Predict values for examples stored in a JavaRDD. + * @param testData JavaRDD representing data points to be predicted + * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + */ + def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = + predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index c80b1134ed..743a43a139 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -19,6 +19,8 @@ package org.apache.spark.mllib.classification; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.junit.After; @@ -87,4 +89,18 @@ public class JavaNaiveBayesSuite implements Serializable { int numAccurate2 = validatePrediction(POINTS, model2); Assert.assertEquals(POINTS.size(), numAccurate2); } + + @Test + public void testPredictJavaRDD() { + JavaRDD examples = sc.parallelize(POINTS, 2).cache(); + NaiveBayesModel model = NaiveBayes.train(examples.rdd()); + JavaRDD vectors = examples.map(new Function() { + @Override + public Vector call(LabeledPoint v) throws Exception { + return v.features(); + }}); + JavaRDD predictions = model.predict(vectors); + // Should be able to get the first prediction. + predictions.first(); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 49a614bd90..0c916ca378 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -88,4 +88,18 @@ public class JavaKMeansSuite implements Serializable { .run(data.rdd()); assertEquals(expectedCenter, model.clusterCenters()[0]); } + + @Test + public void testPredictJavaRDD() { + List points = Lists.newArrayList( + Vectors.dense(1.0, 2.0, 6.0), + Vectors.dense(1.0, 3.0, 0.0), + Vectors.dense(1.0, 4.0, 6.0) + ); + JavaRDD data = sc.parallelize(points, 2); + KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); + JavaRDD predictions = model.predict(data); + // Should be able to get the first prediction. + predictions.first(); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 7151e55351..6dc6877691 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -25,8 +25,10 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.util.LinearDataGenerator; public class JavaLinearRegressionSuite implements Serializable { @@ -92,4 +94,23 @@ public class JavaLinearRegressionSuite implements Serializable { Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } + @Test + public void testPredictJavaRDD() { + int nPoints = 100; + double A = 0.0; + double[] weights = {10, 10}; + JavaRDD testRDD = sc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); + LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); + JavaRDD vectors = testRDD.map(new Function() { + @Override + public Vector call(LabeledPoint v) throws Exception { + return v.features(); + } + }); + JavaRDD predictions = model.predict(vectors); + // Should be able to get the first prediction. + predictions.first(); + } } -- cgit v1.2.3