From aa519e3199cbacf9d06aa123df41f65d57559080 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 29 Apr 2014 00:41:03 -0700 Subject: [SPARK-1636][MLLIB] Move main methods to examples * `NaiveBayes` -> `SparseNaiveBayes` * `KMeans` -> `DenseKMeans` * `SVMWithSGD` and `LogisticRegerssionWithSGD` -> `BinaryClassification` * `ALS` -> `MovieLensALS` * `LinearRegressionWithSGD`, `LassoWithSGD`, and `RidgeRegressionWithSGD` -> `LinearRegression` * `DecisionTree` -> `DecisionTreeRunner` `scopt` is used for parsing command-line parameters. `scopt` has MIT license and it only depends on `scala-library`. Example help message: ~~~ BinaryClassification: an example app for binary classification. Usage: BinaryClassification [options] --numIterations number of iterations --stepSize initial step size, default: 1.0 --algorithm algorithm (SVM,LR), default: LR --regType regularization type (L1,L2), default: L2 --regParam regularization parameter, default: 0.1 input paths to labeled examples in LIBSVM format ~~~ Author: Xiangrui Meng Closes #584 from mengxr/mllib-main and squashes the following commits: 7b58c60 [Xiangrui Meng] minor 6e35d7e [Xiangrui Meng] make imports explicit and fix code style c6178c9 [Xiangrui Meng] update TS PCA/SVD to use new spark-submit 6acff75 [Xiangrui Meng] use scopt for DecisionTreeRunner be86069 [Xiangrui Meng] use main instead of extending App b3edf68 [Xiangrui Meng] move DecisionTree's main method to examples 8bfaa5a [Xiangrui Meng] change NaiveBayesParams to Params fe23dcb [Xiangrui Meng] remove main from KMeans and add DenseKMeans as an example 67f4448 [Xiangrui Meng] remove main methods from linear regression algorithms and add LinearRegression example b066bbc [Xiangrui Meng] remove main from ALS and add MovieLensALS example b040f3b [Xiangrui Meng] change BinaryClassificationParams to Params 577945b [Xiangrui Meng] remove unused imports from NB 3d299bc [Xiangrui Meng] remove main from LR/SVM and add an example app for binary classification f70878e [Xiangrui Meng] remove main from NaiveBayes and add an example NaiveBayes app 01ec2cd [Xiangrui Meng] Merge branch 'master' into mllib-main 9420692 [Xiangrui Meng] add scopt to examples dependencies (cherry picked from commit 3f38334f441940ed0a5bbf5588ca7f22d3940359) Signed-off-by: Reynold Xin --- .../mllib/classification/LogisticRegression.scala | 18 +-- .../spark/mllib/classification/NaiveBayes.scala | 22 +--- .../apache/spark/mllib/classification/SVM.scala | 18 +-- .../org/apache/spark/mllib/clustering/KMeans.scala | 25 +--- .../apache/spark/mllib/recommendation/ALS.scala | 45 +------ .../org/apache/spark/mllib/regression/Lasso.scala | 17 --- .../spark/mllib/regression/LinearRegression.scala | 16 --- .../spark/mllib/regression/RidgeRegression.scala | 19 --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 131 +-------------------- 9 files changed, 7 insertions(+), 304 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 4f9eaacf67..780e8bae42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.classification -import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.DataValidators import org.apache.spark.rdd.RDD /** @@ -183,19 +182,4 @@ object LogisticRegressionWithSGD { numIterations: Int): LogisticRegressionModel = { train(input, numIterations, 1.0, 1.0) } - - def main(args: Array[String]) { - if (args.length != 4) { - println("Usage: LogisticRegression " + - "") - System.exit(1) - } - val sc = new SparkContext(args(0), "LogisticRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights) - println("Intercept: " + model.intercept) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 18658850a2..f6f62ce2de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import org.apache.spark.annotation.Experimental -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD /** @@ -158,23 +157,4 @@ object NaiveBayes { def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { new NaiveBayes(lambda).run(input) } - - def main(args: Array[String]) { - if (args.length != 2 && args.length != 3) { - println("Usage: NaiveBayes []") - System.exit(1) - } - val sc = new SparkContext(args(0), "NaiveBayes") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = if (args.length == 2) { - NaiveBayes.train(data) - } else { - NaiveBayes.train(data, args(2).toDouble) - } - - println("Pi\n: " + model.pi) - println("Theta:\n" + model.theta) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 956654b1fe..81b126717e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.classification -import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.mllib.util.DataValidators import org.apache.spark.rdd.RDD /** @@ -183,19 +182,4 @@ object SVMWithSGD { def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 1.0, 1.0) } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: SVM ") - System.exit(1) - } - val sc = new SparkContext(args(0), "SVM") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - - println("Weights: " + model.weights) - println("Intercept: " + model.intercept) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index dee9ef07e4..a64c5d44be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -21,8 +21,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm} -import org.apache.spark.annotation.Experimental -import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils @@ -396,28 +395,6 @@ object KMeans { v2: BreezeVectorWithNorm): Double = { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } - - @Experimental - def main(args: Array[String]) { - if (args.length < 4) { - println("Usage: KMeans []") - System.exit(1) - } - val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt) - val runs = if (args.length >= 5) args(4).toInt else 1 - val sc = new SparkContext(master, "KMeans") - val data = sc.textFile(inputFile) - .map(line => Vectors.dense(line.split(' ').map(_.toDouble))) - .cache() - val model = KMeans.train(data, k, iters, runs) - val cost = model.computeCost(data) - println("Cluster centers:") - for (c <- model.clusterCenters) { - println(" " + c) - } - println("Cost: " + cost) - System.exit(0) - } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 60fb73f2b5..2a77e1a9ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -23,15 +23,13 @@ import scala.util.Random import scala.util.Sorting import scala.util.hashing.byteswap32 -import com.esotericsoftware.kryo.Kryo import org.jblas.{DoubleMatrix, SimpleBlas, Solve} import org.apache.spark.annotation.Experimental import org.apache.spark.broadcast.Broadcast -import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf} +import org.apache.spark.{Logging, HashPartitioner, Partitioner} import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.SparkContext._ import org.apache.spark.util.Utils @@ -707,45 +705,4 @@ object ALS { : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } - - private class ALSRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[Rating]) - } - } - - def main(args: Array[String]) { - if (args.length < 5 || args.length > 9) { - println("Usage: ALS " + - "[] [] [] []") - System.exit(1) - } - val (master, ratingsFile, rank, iters, outputDir) = - (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) - val lambda = if (args.length >= 6) args(5).toDouble else 0.01 - val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false - val alpha = if (args.length >= 8) args(7).toDouble else 1 - val blocks = if (args.length == 9) args(8).toInt else -1 - val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", classOf[ALSRegistrator].getName) - .set("spark.kryo.referenceTracking", "false") - .set("spark.kryoserializer.buffer.mb", "8") - .set("spark.locality.wait", "10000") - val sc = new SparkContext(master, "ALS", conf) - - val ratings = sc.textFile(ratingsFile).map { line => - val fields = line.split(',') - Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) - } - val model = new ALS(rank = rank, iterations = iters, lambda = lambda, - numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings) - - model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } - .saveAsTextFile(outputDir + "/userFeatures") - model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } - .saveAsTextFile(outputDir + "/productFeatures") - println("Final user/product features written to " + outputDir) - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 5f0812fd2e..0e6fb1b1ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD /** @@ -173,19 +171,4 @@ object LassoWithSGD { numIterations: Int): LassoModel = { train(input, numIterations, 1.0, 1.0, 1.0) } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: Lasso ") - System.exit(1) - } - val sc = new SparkContext(args(0), "Lasso") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - - println("Weights: " + model.weights) - println("Intercept: " + model.intercept) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 228fa8db3e..1532ff90d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.regression -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.util.MLUtils /** * Regression model trained using LinearRegression. @@ -156,18 +154,4 @@ object LinearRegressionWithSGD { numIterations: Int): LinearRegressionModel = { train(input, numIterations, 1.0, 1.0) } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: LinearRegression ") - System.exit(1) - } - val sc = new SparkContext(args(0), "LinearRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights) - println("Intercept: " + model.intercept) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index e702027c7c..5f7e25a9b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -17,10 +17,8 @@ package org.apache.spark.mllib.regression -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.linalg.Vector /** @@ -170,21 +168,4 @@ object RidgeRegressionWithSGD { numIterations: Int): RidgeRegressionModel = { train(input, numIterations, 1.0, 1.0, 1.0) } - - def main(args: Array[String]) { - if (args.length != 5) { - println("Usage: RidgeRegression " + - " ") - System.exit(1) - } - val sc = new SparkContext(args(0), "RidgeRegression") - val data = MLUtils.loadLabeledData(sc, args(1)) - val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, - args(3).toDouble) - - println("Weights: " + model.weights) - println("Intercept: " + model.intercept) - - sc.stop() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f68076f426..59ed01debf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,18 +18,16 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Experimental -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.SparkContext._ +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: Experimental :: @@ -1028,129 +1026,4 @@ object DecisionTree extends Serializable with Logging { throw new UnsupportedOperationException("approximate histogram not supported yet.") } } - - private val usage = """ - Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] - """ - - def main(args: Array[String]) { - - if (args.length < 2) { - System.err.println(usage) - System.exit(1) - } - - val sc = new SparkContext(args(0), "DecisionTree") - - val argList = args.toList.drop(1) - type OptionMap = Map[Symbol, Any] - - def nextOption(map : OptionMap, list: List[String]): OptionMap = { - list match { - case Nil => map - case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) - case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) - case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) - case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) - case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) - , tail) - case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), - tail) - case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) - case option :: tail => logError("Unknown option " + option) - sys.exit(1) - } - } - val options = nextOption(Map(), argList) - logDebug(options.toString()) - - // Load training data. - val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - - // Identify the type of algorithm. - val algoStr = options.get('algo).get.toString - val algo = algoStr match { - case "Classification" => Classification - case "Regression" => Regression - } - - // Identify the type of impurity. - val impurityStr = options.getOrElse('impurity, - if (algo == Classification) "Gini" else "Variance").toString - val impurity = impurityStr match { - case "Gini" => Gini - case "Entropy" => Entropy - case "Variance" => Variance - } - - val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt - val maxBins = options.getOrElse('maxBins, "100").toString.toInt - - val strategy = new Strategy(algo, impurity, maxDepth, maxBins) - val model = DecisionTree.train(trainData, strategy) - - // Load test data. - val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) - - // Measure algorithm accuracy - if (algo == Classification) { - val accuracy = accuracyScore(model, testData) - logDebug("accuracy = " + accuracy) - } - - if (algo == Regression) { - val mse = meanSquaredError(model, testData) - logDebug("mean square error = " + mse) - } - - sc.stop() - } - - /** - * Load labeled data from a file. The data format used here is - * , ..., - * where , are feature values in Double and is the corresponding label as Double. - * - * @param sc SparkContext - * @param dir Directory to the input data files. - * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is - * the label, and the second element represents the feature values (an array of Double). - */ - private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { - sc.textFile(dir).map { line => - val parts = line.trim().split(",") - val label = parts(0).toDouble - val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble)) - LabeledPoint(label, features) - } - } - - // TODO: Port this method to a generic metrics package. - /** - * Calculates the classifier accuracy. - */ - private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], - threshold: Double = 0.5): Double = { - def predictedValue(features: Vector) = { - if (model.predict(features) < threshold) 0.0 else 1.0 - } - val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() - val count = data.count() - logDebug("correct prediction count = " + correctCount) - logDebug("data count = " + count) - correctCount.toDouble / count - } - - // TODO: Port this method to a generic metrics package - /** - * Calculates the mean squared error for regression. - */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } } -- cgit v1.2.3