path: root/mllib
author     Xiangrui Meng <meng@databricks.com>    2014-04-29 00:41:03 -0700
committer  Reynold Xin <rxin@apache.org>          2014-04-29 00:41:15 -0700
commit     aa519e3199cbacf9d06aa123df41f65d57559080
tree       12ffbfdf799130f4c6c2c254823ac5deef23d800 /mllib
parent     0995787ac6c71b7786ec68e3ee6f572ad7bf56a3
download   spark-aa519e3199cbacf9d06aa123df41f65d57559080.{tar.gz, tar.bz2, zip}
[SPARK-1636][MLLIB] Move main methods to examples
* `NaiveBayes` -> `SparseNaiveBayes`
* `KMeans` -> `DenseKMeans`
* `SVMWithSGD` and `LogisticRegressionWithSGD` -> `BinaryClassification`
* `ALS` -> `MovieLensALS`
* `LinearRegressionWithSGD`, `LassoWithSGD`, and `RidgeRegressionWithSGD` -> `LinearRegression`
* `DecisionTree` -> `DecisionTreeRunner`

`scopt` is used for parsing command-line parameters. `scopt` is MIT-licensed and depends only on `scala-library`.

Example help message:

~~~
BinaryClassification: an example app for binary classification.
Usage: BinaryClassification [options] <input>

  --numIterations <value>   number of iterations
  --stepSize <value>        initial step size, default: 1.0
  --algorithm <value>       algorithm (SVM,LR), default: LR
  --regType <value>         regularization type (L1,L2), default: L2
  --regParam <value>        regularization parameter, default: 0.1
  <input>                   input paths to labeled examples in LIBSVM format
~~~

Author: Xiangrui Meng <meng@databricks.com>

Closes #584 from mengxr/mllib-main and squashes the following commits:

7b58c60 [Xiangrui Meng] minor
6e35d7e [Xiangrui Meng] make imports explicit and fix code style
c6178c9 [Xiangrui Meng] update TS PCA/SVD to use new spark-submit
6acff75 [Xiangrui Meng] use scopt for DecisionTreeRunner
be86069 [Xiangrui Meng] use main instead of extending App
b3edf68 [Xiangrui Meng] move DecisionTree's main method to examples
8bfaa5a [Xiangrui Meng] change NaiveBayesParams to Params
fe23dcb [Xiangrui Meng] remove main from KMeans and add DenseKMeans as an example
67f4448 [Xiangrui Meng] remove main methods from linear regression algorithms and add LinearRegression example
b066bbc [Xiangrui Meng] remove main from ALS and add MovieLensALS example
b040f3b [Xiangrui Meng] change BinaryClassificationParams to Params
577945b [Xiangrui Meng] remove unused imports from NB
3d299bc [Xiangrui Meng] remove main from LR/SVM and add an example app for binary classification
f70878e [Xiangrui Meng] remove main from NaiveBayes and add an example NaiveBayes app
01ec2cd [Xiangrui Meng] Merge branch 'master' into mllib-main
9420692 [Xiangrui Meng] add scopt to examples dependencies

(cherry picked from commit 3f38334f441940ed0a5bbf5588ca7f22d3940359)
Signed-off-by: Reynold Xin <rxin@apache.org>
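The help text above is the output of a scopt `OptionParser`. The following is a rough sketch of how such a parser could be wired up with scopt 3.x; it is not the code added by this commit, and the `Params` case class, the object name, and the default values are illustrative assumptions only.

~~~scala
import scopt.OptionParser

// Hypothetical parameter holder; the real example's fields and defaults may differ.
case class Params(
    input: String = null,
    numIterations: Int = 100,
    stepSize: Double = 1.0,
    algorithm: String = "LR",
    regType: String = "L2",
    regParam: Double = 0.1)

object BinaryClassificationCli {
  def main(args: Array[String]) {
    val parser = new OptionParser[Params]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text("number of iterations")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("stepSize")
        .text("initial step size, default: 1.0")
        .action((x, c) => c.copy(stepSize = x))
      opt[String]("algorithm")
        .text("algorithm (SVM,LR), default: LR")
        .action((x, c) => c.copy(algorithm = x))
      opt[String]("regType")
        .text("regularization type (L1,L2), default: L2")
        .action((x, c) => c.copy(regType = x))
      opt[Double]("regParam")
        .text("regularization parameter, default: 0.1")
        .action((x, c) => c.copy(regParam = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
    }
    // parse returns None (and prints the usage message) on bad arguments.
    parser.parse(args, Params()) match {
      case Some(params) => println(params) // hand off to the training code here
      case None => sys.exit(1)
    }
  }
}
~~~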
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala   18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala           22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala                  18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala                   25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala                  45
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala                    17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala         16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala          19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                  131
9 files changed, 7 insertions, 304 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4f9eaacf67..780e8bae42 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,11 +17,10 @@
package org.apache.spark.mllib.classification
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.rdd.RDD
/**
@@ -183,19 +182,4 @@ object LogisticRegressionWithSGD {
numIterations: Int): LogisticRegressionModel = {
train(input, numIterations, 1.0, 1.0)
}
-
- def main(args: Array[String]) {
- if (args.length != 4) {
- println("Usage: LogisticRegression <master> <input_dir> <step_size> " +
- "<niters>")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "LogisticRegression")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
- println("Weights: " + model.weights)
- println("Intercept: " + model.intercept)
-
- sc.stop()
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 18658850a2..f6f62ce2de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
/**
@@ -158,23 +157,4 @@ object NaiveBayes {
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
new NaiveBayes(lambda).run(input)
}
-
- def main(args: Array[String]) {
- if (args.length != 2 && args.length != 3) {
- println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "NaiveBayes")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = if (args.length == 2) {
- NaiveBayes.train(data)
- } else {
- NaiveBayes.train(data, args(2).toDouble)
- }
-
- println("Pi\n: " + model.pi)
- println("Theta:\n" + model.theta)
-
- sc.stop()
- }
}
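The removed `main` above only loaded labeled data and called the public `NaiveBayes.train` API. A minimal standalone driver doing the same thing could look like the sketch below; it is not the `SparseNaiveBayes` example added by this commit, and the object name, the assumed "label,f1 f2 f3" input layout, and the smoothing value are placeholders.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object NaiveBayesDriver {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("NaiveBayesDriver"))
    // Assumed layout: "<label>,<f1> <f2> ..." with space-separated feature values.
    val data = sc.textFile(args(0)).map { line =>
      val parts = line.trim.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
    }
    // Same call the removed main made; lambda is an additive-smoothing parameter.
    val model = NaiveBayes.train(data, lambda = 1.0)
    println("Pi: " + model.pi)
    println("Theta: " + model.theta)
    sc.stop()
  }
}
~~~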
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 956654b1fe..81b126717e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,11 +17,10 @@
package org.apache.spark.mllib.classification
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.rdd.RDD
/**
@@ -183,19 +182,4 @@ object SVMWithSGD {
def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
-
- def main(args: Array[String]) {
- if (args.length != 5) {
- println("Usage: SVM <master> <input_dir> <step_size> <regularization_parameter> <niters>")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "SVM")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
-
- println("Weights: " + model.weights)
- println("Intercept: " + model.intercept)
-
- sc.stop()
- }
}
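The two removed mains above (LogisticRegression.scala and SVM.scala) both reduce to building an `RDD[LabeledPoint]` and calling the corresponding `train` method. A combined driver in that spirit is sketched below; it is not the `BinaryClassification` example added by this commit, and the object name, input layout, and parameter values are placeholders.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

object BinaryClassificationDriver {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("BinaryClassificationDriver"))
    // Assumed layout: "<label>,<f1> <f2> ..." with space-separated feature values.
    val data = sc.textFile(args(0)).map { line =>
      val parts = line.trim.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
    }.cache()
    val (numIterations, stepSize, regParam) = (100, 1.0, 0.1)
    // Same train signatures the removed mains used:
    //   LogisticRegressionWithSGD.train(input, numIterations, stepSize)
    //   SVMWithSGD.train(input, numIterations, stepSize, regParam)
    val lrModel = LogisticRegressionWithSGD.train(data, numIterations, stepSize)
    val svmModel = SVMWithSGD.train(data, numIterations, stepSize, regParam)
    println("LR  weights: " + lrModel.weights + " intercept: " + lrModel.intercept)
    println("SVM weights: " + svmModel.weights + " intercept: " + svmModel.intercept)
    sc.stop()
  }
}
~~~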
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index dee9ef07e4..a64c5d44be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -21,8 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
@@ -396,28 +395,6 @@ object KMeans {
v2: BreezeVectorWithNorm): Double = {
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
}
-
- @Experimental
- def main(args: Array[String]) {
- if (args.length < 4) {
- println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]")
- System.exit(1)
- }
- val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
- val runs = if (args.length >= 5) args(4).toInt else 1
- val sc = new SparkContext(master, "KMeans")
- val data = sc.textFile(inputFile)
- .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
- .cache()
- val model = KMeans.train(data, k, iters, runs)
- val cost = model.computeCost(data)
- println("Cluster centers:")
- for (c <- model.clusterCenters) {
- println(" " + c)
- }
- println("Cost: " + cost)
- System.exit(0)
- }
}
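For the clustering case, the removed main parsed whitespace-separated dense vectors, called `KMeans.train`, and reported the cluster centers and cost. A minimal sketch of the same flow follows; it is not the `DenseKMeans` example added by this commit, and the driver name and argument order are assumptions.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical driver; args: <input_file> <k> <max_iterations>.
object KMeansDriver {
  def main(args: Array[String]) {
    val Array(input, k, maxIterations) = args
    val sc = new SparkContext(new SparkConf().setAppName("KMeansDriver"))
    // Space-separated dense vectors, one point per line, as in the removed main.
    val data = sc.textFile(input)
      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
      .cache()
    val model = KMeans.train(data, k.toInt, maxIterations.toInt)
    println("Cluster centers:")
    model.clusterCenters.foreach(c => println("  " + c))
    println("Cost: " + model.computeCost(data))
    sc.stop()
  }
}
~~~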
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 60fb73f2b5..2a77e1a9ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -23,15 +23,13 @@ import scala.util.Random
import scala.util.Sorting
import scala.util.hashing.byteswap32
-import com.esotericsoftware.kryo.Kryo
import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf}
+import org.apache.spark.{Logging, HashPartitioner, Partitioner}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.SparkContext._
import org.apache.spark.util.Utils
@@ -707,45 +705,4 @@ object ALS {
: MatrixFactorizationModel = {
trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
}
-
- private class ALSRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[Rating])
- }
- }
-
- def main(args: Array[String]) {
- if (args.length < 5 || args.length > 9) {
- println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> " +
- "[<lambda>] [<implicitPrefs>] [<alpha>] [<blocks>]")
- System.exit(1)
- }
- val (master, ratingsFile, rank, iters, outputDir) =
- (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
- val lambda = if (args.length >= 6) args(5).toDouble else 0.01
- val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
- val alpha = if (args.length >= 8) args(7).toDouble else 1
- val blocks = if (args.length == 9) args(8).toInt else -1
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
- .set("spark.kryo.referenceTracking", "false")
- .set("spark.kryoserializer.buffer.mb", "8")
- .set("spark.locality.wait", "10000")
- val sc = new SparkContext(master, "ALS", conf)
-
- val ratings = sc.textFile(ratingsFile).map { line =>
- val fields = line.split(',')
- Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
- }
- val model = new ALS(rank = rank, iterations = iters, lambda = lambda,
- numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings)
-
- model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
- .saveAsTextFile(outputDir + "/userFeatures")
- model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
- .saveAsTextFile(outputDir + "/productFeatures")
- println("Final user/product features written to " + outputDir)
- sc.stop()
- }
}
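The removed ALS main parsed "user,product,rating" lines into `Rating` objects, configured Kryo serialization, and ran factorization via the (private) `ALS` constructor. The sketch below covers the same ground through the public `ALS.train` factory method; it is not the `MovieLensALS` example added by this commit, and the object name, rank, iteration count, and lambda are placeholders.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Hypothetical driver; args: <ratings_file> <output_dir>.
object ALSDriver {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setAppName("ALSDriver")
      // Kryo serialization, as the removed main configured; its private
      // ALSRegistrator is not reachable from user code, so no registrator is set here.
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    // "user,product,rating" lines, matching the removed main's parsing.
    val ratings = sc.textFile(args(0)).map { line =>
      val fields = line.split(',')
      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
    }
    // Public factory method instead of the private constructor the removed main used.
    val model = ALS.train(ratings, 10, 10, 0.01)  // rank, iterations, lambda
    model.userFeatures.map { case (id, vec) => id + "," + vec.mkString(" ") }
      .saveAsTextFile(args(1) + "/userFeatures")
    model.productFeatures.map { case (id, vec) => id + "," + vec.mkString(" ") }
      .saveAsTextFile(args(1) + "/productFeatures")
    sc.stop()
  }
}
~~~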
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 5f0812fd2e..0e6fb1b1ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,10 +17,8 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
/**
@@ -173,19 +171,4 @@ object LassoWithSGD {
numIterations: Int): LassoModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
-
- def main(args: Array[String]) {
- if (args.length != 5) {
- println("Usage: Lasso <master> <input_dir> <step_size> <regularization_parameter> <niters>")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "Lasso")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
-
- println("Weights: " + model.weights)
- println("Intercept: " + model.intercept)
-
- sc.stop()
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 228fa8db3e..1532ff90d8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
/**
* Regression model trained using LinearRegression.
@@ -156,18 +154,4 @@ object LinearRegressionWithSGD {
numIterations: Int): LinearRegressionModel = {
train(input, numIterations, 1.0, 1.0)
}
-
- def main(args: Array[String]) {
- if (args.length != 5) {
- println("Usage: LinearRegression <master> <input_dir> <step_size> <niters>")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "LinearRegression")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
- println("Weights: " + model.weights)
- println("Intercept: " + model.intercept)
-
- sc.stop()
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index e702027c7c..5f7e25a9b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,10 +17,8 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.linalg.Vector
/**
@@ -170,21 +168,4 @@ object RidgeRegressionWithSGD {
numIterations: Int): RidgeRegressionModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
-
- def main(args: Array[String]) {
- if (args.length != 5) {
- println("Usage: RidgeRegression <master> <input_dir> <step_size> " +
- "<regularization_parameter> <niters>")
- System.exit(1)
- }
- val sc = new SparkContext(args(0), "RidgeRegression")
- val data = MLUtils.loadLabeledData(sc, args(1))
- val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
- args(3).toDouble)
-
- println("Weights: " + model.weights)
- println("Intercept: " + model.intercept)
-
- sc.stop()
- }
}
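The three removed regression mains (Lasso.scala, LinearRegression.scala, and this RidgeRegression.scala) differ only in which `train` method they call, which is why the commit folds them into a single `LinearRegression` example. A sketch exercising all three public APIs follows; it is not that example, and the object name, input layout, and parameter values are placeholders.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LassoWithSGD, LinearRegressionWithSGD, RidgeRegressionWithSGD}

object LinearRegressionDriver {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("LinearRegressionDriver"))
    // Assumed layout: "<label>,<f1> <f2> ..." with space-separated feature values.
    val data = sc.textFile(args(0)).map { line =>
      val parts = line.trim.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).trim.split(' ').map(_.toDouble)))
    }.cache()
    val (numIterations, stepSize, regParam) = (100, 1.0, 0.1)
    // Same train signatures the three removed mains relied on.
    val ols = LinearRegressionWithSGD.train(data, numIterations, stepSize)
    val lasso = LassoWithSGD.train(data, numIterations, stepSize, regParam)
    val ridge = RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam)
    println("OLS   weights: " + ols.weights + " intercept: " + ols.intercept)
    println("Lasso weights: " + lasso.weights + " intercept: " + lasso.intercept)
    println("Ridge weights: " + ridge.weights + " intercept: " + ridge.intercept)
    sc.stop()
  }
}
~~~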
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f68076f426..59ed01debf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -18,18 +18,16 @@
package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.SparkContext._
+import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
* :: Experimental ::
@@ -1028,129 +1026,4 @@ object DecisionTree extends Serializable with Logging {
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}
-
- private val usage = """
- Usage: DecisionTreeRunner <master>[slices] --algo <Classification,
- Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,
- Variance>] [--maxBins num]
- """
-
- def main(args: Array[String]) {
-
- if (args.length < 2) {
- System.err.println(usage)
- System.exit(1)
- }
-
- val sc = new SparkContext(args(0), "DecisionTree")
-
- val argList = args.toList.drop(1)
- type OptionMap = Map[Symbol, Any]
-
- def nextOption(map : OptionMap, list: List[String]): OptionMap = {
- list match {
- case Nil => map
- case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
- case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
- case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
- case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
- case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
- , tail)
- case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
- tail)
- case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
- case option :: tail => logError("Unknown option " + option)
- sys.exit(1)
- }
- }
- val options = nextOption(Map(), argList)
- logDebug(options.toString())
-
- // Load training data.
- val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
-
- // Identify the type of algorithm.
- val algoStr = options.get('algo).get.toString
- val algo = algoStr match {
- case "Classification" => Classification
- case "Regression" => Regression
- }
-
- // Identify the type of impurity.
- val impurityStr = options.getOrElse('impurity,
- if (algo == Classification) "Gini" else "Variance").toString
- val impurity = impurityStr match {
- case "Gini" => Gini
- case "Entropy" => Entropy
- case "Variance" => Variance
- }
-
- val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
- val maxBins = options.getOrElse('maxBins, "100").toString.toInt
-
- val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
- val model = DecisionTree.train(trainData, strategy)
-
- // Load test data.
- val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
-
- // Measure algorithm accuracy
- if (algo == Classification) {
- val accuracy = accuracyScore(model, testData)
- logDebug("accuracy = " + accuracy)
- }
-
- if (algo == Regression) {
- val mse = meanSquaredError(model, testData)
- logDebug("mean square error = " + mse)
- }
-
- sc.stop()
- }
-
- /**
- * Load labeled data from a file. The data format used here is
- * <L>, <f1> <f2> ...,
- * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
- *
- * @param sc SparkContext
- * @param dir Directory to the input data files.
- * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
- * the label, and the second element represents the feature values (an array of Double).
- */
- private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
- sc.textFile(dir).map { line =>
- val parts = line.trim().split(",")
- val label = parts(0).toDouble
- val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
- LabeledPoint(label, features)
- }
- }
-
- // TODO: Port this method to a generic metrics package.
- /**
- * Calculates the classifier accuracy.
- */
- private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
- threshold: Double = 0.5): Double = {
- def predictedValue(features: Vector) = {
- if (model.predict(features) < threshold) 0.0 else 1.0
- }
- val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
- val count = data.count()
- logDebug("correct prediction count = " + correctCount)
- logDebug("data count = " + count)
- correctCount.toDouble / count
- }
-
- // TODO: Port this method to a generic metrics package
- /**
- * Calculates the mean squared error for regression.
- */
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
}
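The removed DecisionTree main parsed comma-separated labeled points, built a `Strategy(algo, impurity, maxDepth, maxBins)`, trained with `DecisionTree.train`, and measured classification accuracy. The sketch below retraces that flow against the public API; it is not the `DecisionTreeRunner` example added by this commit, and the object name, depth, and bin count are placeholders.

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical driver; args: <train_dir> <test_dir>.
object DecisionTreeDriver {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("DecisionTreeDriver"))
    // Comma-separated "<label>,<f1>,<f2>,..." lines, as the removed loadLabeledData helper parsed.
    def load(path: String) = sc.textFile(path).map { line =>
      val parts = line.trim.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts.drop(1).map(_.toDouble)))
    }
    val trainData = load(args(0)).cache()
    val testData = load(args(1)).cache()
    // Strategy(algo, impurity, maxDepth, maxBins), the same positional call the removed main made.
    val strategy = new Strategy(Classification, Gini, 4, 100)
    val model = DecisionTree.train(trainData, strategy)
    // Same 0/1 thresholding the removed accuracyScore helper applied.
    val correct = testData.filter { p =>
      (if (model.predict(p.features) < 0.5) 0.0 else 1.0) == p.label
    }.count()
    println("accuracy = " + correct.toDouble / testData.count())
    sc.stop()
  }
}
~~~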