From 04fa1223ee69760f0d23b40e56f4b036aa301879 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 30 Jun 2014 16:03:38 -0700 Subject: SPARK-2293. Replace RDD.zip usage by map with predict inside. This is the only occurrence of this pattern in the examples that needs to be replaced. It only addresses the example change. Author: Sean Owen Closes #1250 from srowen/SPARK-2293 and squashes the following commits: 6b1b28c [Sean Owen] Compute prediction-and-label RDD directly rather than by zipping, for efficiency --- docs/mllib-naive-bayes.md | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) (limited to 'docs') diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 4b3a7cab32..1d1d7dcf6f 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -51,9 +51,8 @@ val training = splits(0) val test = splits(1) val model = NaiveBayes.train(training, lambda = 1.0) -val prediction = model.predict(test.map(_.features)) -val predictionAndLabel = prediction.zip(test.map(_.label)) +val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() {% endhighlight %} @@ -71,6 +70,7 @@ can be used for evaluation and prediction. import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.regression.LabeledPoint; @@ -81,18 +81,12 @@ JavaRDD test = ... // test set final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); -JavaRDD prediction = - test.map(new Function() { - @Override public Double call(LabeledPoint p) { - return model.predict(p.features()); - } - }); JavaPairRDD predictionAndLabel = - prediction.zip(test.map(new Function() { - @Override public Double call(LabeledPoint p) { - return p.label(); + test.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); } - })); + }); double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { return pl._1() == pl._2(); -- cgit v1.2.3