From aa519e3199cbacf9d06aa123df41f65d57559080 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 29 Apr 2014 00:41:03 -0700
Subject: [SPARK-1636][MLLIB] Move main methods to examples

* `NaiveBayes` -> `SparseNaiveBayes`
* `KMeans` -> `DenseKMeans`
* `SVMWithSGD` and `LogisticRegressionWithSGD` -> `BinaryClassification`
* `ALS` -> `MovieLensALS`
* `LinearRegressionWithSGD`, `LassoWithSGD`, and `RidgeRegressionWithSGD` -> `LinearRegression`
* `DecisionTree` -> `DecisionTreeRunner`

`scopt` is used for parsing command-line parameters. `scopt` is MIT-licensed and depends only on `scala-library`.

Example help message:

~~~
BinaryClassification: an example app for binary classification.
Usage: BinaryClassification [options] <input>

  --numIterations <value>
        number of iterations
  --stepSize <value>
        initial step size, default: 1.0
  --algorithm <value>
        algorithm (SVM,LR), default: LR
  --regType <value>
        regularization type (L1,L2), default: L2
  --regParam <value>
        regularization parameter, default: 0.1
  <input>
        input paths to labeled examples in LIBSVM format
~~~

Author: Xiangrui Meng

Closes #584 from mengxr/mllib-main and squashes the following commits:

7b58c60 [Xiangrui Meng] minor
6e35d7e [Xiangrui Meng] make imports explicit and fix code style
c6178c9 [Xiangrui Meng] update TS PCA/SVD to use new spark-submit
6acff75 [Xiangrui Meng] use scopt for DecisionTreeRunner
be86069 [Xiangrui Meng] use main instead of extending App
b3edf68 [Xiangrui Meng] move DecisionTree's main method to examples
8bfaa5a [Xiangrui Meng] change NaiveBayesParams to Params
fe23dcb [Xiangrui Meng] remove main from KMeans and add DenseKMeans as an example
67f4448 [Xiangrui Meng] remove main methods from linear regression algorithms and add LinearRegression example
b066bbc [Xiangrui Meng] remove main from ALS and add MovieLensALS example
b040f3b [Xiangrui Meng] change BinaryClassificationParams to Params
577945b [Xiangrui Meng] remove unused imports from NB
3d299bc [Xiangrui Meng] remove main from LR/SVM and add an example app for binary classification
f70878e [Xiangrui Meng] remove main from NaiveBayes and add an example NaiveBayes app
01ec2cd [Xiangrui Meng] Merge branch 'master' into mllib-main
9420692 [Xiangrui Meng] add scopt to examples dependencies

(cherry picked from commit 3f38334f441940ed0a5bbf5588ca7f22d3940359)
Signed-off-by: Reynold Xin
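All six new apps share the same parsing pattern: an immutable `Params` case class whose defaults double as the documented defaults, updated through scopt actions. A minimal, self-contained sketch of that pattern (the `Sketch` app name and its options are placeholders, not part of this patch):

~~~
import scopt.OptionParser

object Sketch {
  // Immutable config; every scopt action returns an updated copy.
  case class Params(input: String = null, numIterations: Int = 100)

  def main(args: Array[String]) {
    val defaultParams = Params()
    val parser = new OptionParser[Params]("Sketch") {
      head("Sketch: a minimal scopt example.")
      opt[Int]("numIterations")
        .text(s"number of iterations, default: ${defaultParams.numIterations}")
        .action((x, c) => c.copy(numIterations = x))
      arg[String]("<input>")
        .required()
        .text("input path")
        .action((x, c) => c.copy(input = x))
    }
    // parse returns Some(updated params) on success and None on bad
    // arguments, after printing the usage text above to stderr.
    parser.parse(args, defaultParams).map { params =>
      println(s"would run with $params")
    }.getOrElse {
      sys.exit(1)
    }
  }
}
~~~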
---
 examples/pom.xml                                    |   5 +
 .../examples/mllib/BinaryClassification.scala       | 145 ++++++++++++++++++
 .../spark/examples/mllib/DecisionTreeRunner.scala   | 161 ++++++++++++++++++++
 .../apache/spark/examples/mllib/DenseKMeans.scala   | 109 ++++++++++++++
 .../spark/examples/mllib/LinearRegression.scala     | 125 ++++++++++++++++
 .../apache/spark/examples/mllib/MovieLensALS.scala  | 131 +++++++++++++++++
 .../spark/examples/mllib/SparseNaiveBayes.scala     | 102 +++++++++++++
 .../spark/examples/mllib/TallSkinnyPCA.scala        |  12 +-
 .../spark/examples/mllib/TallSkinnySVD.scala        |  12 +-
 .../mllib/classification/LogisticRegression.scala   |  18 +--
 .../spark/mllib/classification/NaiveBayes.scala     |  22 +--
 .../apache/spark/mllib/classification/SVM.scala     |  18 +--
 .../org/apache/spark/mllib/clustering/KMeans.scala  |  25 +---
 .../apache/spark/mllib/recommendation/ALS.scala     |  45 +-----
 .../org/apache/spark/mllib/regression/Lasso.scala   |  17 ---
 .../spark/mllib/regression/LinearRegression.scala   |  16 --
 .../spark/mllib/regression/RidgeRegression.scala    |  19 ---
 .../org/apache/spark/mllib/tree/DecisionTree.scala  | 131 +---------------
 project/SparkBuild.scala                            |   3 +-
 19 files changed, 795 insertions(+), 321 deletions(-)
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
 create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala

diff --git a/examples/pom.xml b/examples/pom.xml
index d4028c88cb..342d7c0742 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -166,6 +166,11 @@
+    <dependency>
+      <groupId>com.github.scopt</groupId>
+      <artifactId>scopt_${scala.binary.version}</artifactId>
+      <version>3.2.0</version>
+    </dependency>
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
new file mode 100644
index 0000000000..ec9de022c1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
+import org.apache.spark.mllib.evaluation.binary.BinaryClassificationMetrics
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}
+
+/**
+ * An example app for binary classification. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.BinaryClassification
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your
+ * app.
+ */
+object BinaryClassification {
+
+  object Algorithm extends Enumeration {
+    type Algorithm = Value
+    val SVM, LR = Value
+  }
+
+  object RegType extends Enumeration {
+    type RegType = Value
+    val L1, L2 = Value
+  }
+
+  import Algorithm._
+  import RegType._
+
+  case class Params(
+      input: String = null,
+      numIterations: Int = 100,
+      stepSize: Double = 1.0,
+      algorithm: Algorithm = LR,
+      regType: RegType = L2,
+      regParam: Double = 0.1)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("BinaryClassification") {
+      head("BinaryClassification: an example app for binary classification.")
+      opt[Int]("numIterations")
+        .text("number of iterations")
+        .action((x, c) => c.copy(numIterations = x))
+      opt[Double]("stepSize")
+        .text(s"initial step size, default: ${defaultParams.stepSize}")
+        .action((x, c) => c.copy(stepSize = x))
+      opt[String]("algorithm")
+        .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
+          s"default: ${defaultParams.algorithm}")
+        .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
+      opt[String]("regType")
+        .text(s"regularization type (${RegType.values.mkString(",")}), " +
+          s"default: ${defaultParams.regType}")
+        .action((x, c) => c.copy(regType = RegType.withName(x)))
+      opt[Double]("regParam")
+        .text(s"regularization parameter, default: ${defaultParams.regParam}")
+        .action((x, c) => c.copy(regParam = x))
+      arg[String]("<input>")
+        .required()
+        .text("input paths to labeled examples in LIBSVM format")
+        .action((x, c) => c.copy(input = x))
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    } getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    val examples = MLUtils.loadLibSVMData(sc, params.input).cache()
+
+    val splits = examples.randomSplit(Array(0.8, 0.2))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
+
+    val numTraining = training.count()
+    val numTest = test.count()
+    println(s"Training: $numTraining, test: $numTest.")
+
+    examples.unpersist(blocking = false)
+
+    val updater = params.regType match {
+      case L1 => new L1Updater()
+      case L2 => new SquaredL2Updater()
+    }
+
+    val model = params.algorithm match {
+      case LR =>
+        val algorithm = new LogisticRegressionWithSGD()
+        algorithm.optimizer
+          .setNumIterations(params.numIterations)
+          .setStepSize(params.stepSize)
+          .setUpdater(updater)
+          .setRegParam(params.regParam)
+        algorithm.run(training).clearThreshold()
+      case SVM =>
+        val algorithm = new SVMWithSGD()
+        algorithm.optimizer
+          .setNumIterations(params.numIterations)
+          .setStepSize(params.stepSize)
+          .setUpdater(updater)
+          .setRegParam(params.regParam)
+        algorithm.run(training).clearThreshold()
+    }
+
+    val prediction = model.predict(test.map(_.features))
+    val predictionAndLabel = prediction.zip(test.map(_.label))
+
+    val metrics = new BinaryClassificationMetrics(predictionAndLabel)
+
+    println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
+    println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")
+
+    sc.stop()
+  }
+}
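For reference, the LIBSVM input consumed above is one example per line: a label followed by `index:value` pairs, with 1-based ascending indices by LIBSVM's usual convention. A self-contained sketch of the format (the `LibSVMLineSketch` name and parser are illustrative, not the `MLUtils.loadLibSVMData` implementation):

~~~
object LibSVMLineSketch {
  // Parses "1.0 1:0.0 3:2.5" into a label plus sparse (index, value) pairs.
  def parse(line: String): (Double, Array[(Int, Double)]) = {
    val tokens = line.trim.split(' ')
    val label = tokens.head.toDouble
    val features = tokens.tail.map { t =>
      val Array(i, v) = t.split(':')
      (i.toInt - 1, v.toDouble)  // convert 1-based LIBSVM indices to 0-based
    }
    (label, features)
  }

  def main(args: Array[String]) {
    val (label, features) = parse("1.0 1:0.0 3:2.5")
    println(s"label = $label, features = ${features.mkString(" ")}")
  }
}
~~~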
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
new file mode 100644
index 0000000000..0bd847d7ba
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree, impurity}
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+
+/**
+ * An example runner for decision tree. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your
+ * app.
+ */
+object DecisionTreeRunner {
+
+  object ImpurityType extends Enumeration {
+    type ImpurityType = Value
+    val Gini, Entropy, Variance = Value
+  }
+
+  import ImpurityType._
+
+  case class Params(
+      input: String = null,
+      algo: Algo = Classification,
+      maxDepth: Int = 5,
+      impurity: ImpurityType = Gini,
+      maxBins: Int = 20)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DecisionTreeRunner") {
+      head("DecisionTreeRunner: an example decision tree app.")
+      opt[String]("algo")
+        .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+        .action((x, c) => c.copy(algo = Algo.withName(x)))
+      opt[String]("impurity")
+        .text(s"impurity type (${ImpurityType.values.mkString(",")}), " +
+          s"default: ${defaultParams.impurity}")
+        .action((x, c) => c.copy(impurity = ImpurityType.withName(x)))
+      opt[Int]("maxDepth")
+        .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+        .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("maxBins")
+        .text(s"max number of bins, default: ${defaultParams.maxBins}")
+        .action((x, c) => c.copy(maxBins = x))
+      arg[String]("<input>")
+        .text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
+        .required()
+        .action((x, c) => c.copy(input = x))
+      checkConfig { params =>
+        if (params.algo == Classification &&
+            (params.impurity == Gini || params.impurity == Entropy)) {
+          success
+        } else if (params.algo == Regression && params.impurity == Variance) {
+          success
+        } else {
+          failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+        }
+      }
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName("DecisionTreeRunner")
+    val sc = new SparkContext(conf)
+
+    // Load training data and cache it.
+    val examples = MLUtils.loadLabeledData(sc, params.input).cache()
+
+    val splits = examples.randomSplit(Array(0.8, 0.2))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
+
+    val numTraining = training.count()
+    val numTest = test.count()
+
+    println(s"numTraining = $numTraining, numTest = $numTest.")
+
+    examples.unpersist(blocking = false)
+
+    val impurityCalculator = params.impurity match {
+      case Gini => impurity.Gini
+      case Entropy => impurity.Entropy
+      case Variance => impurity.Variance
+    }
+
+    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val model = DecisionTree.train(training, strategy)
+
+    if (params.algo == Classification) {
+      val accuracy = accuracyScore(model, test)
+      println(s"Test accuracy = $accuracy.")
+    }
+
+    if (params.algo == Regression) {
+      val mse = meanSquaredError(model, test)
+      println(s"Test mean squared error = $mse.")
+    }
+
+    sc.stop()
+  }
+
+  /**
+   * Calculates the classifier accuracy.
+   */
+  private def accuracyScore(
+      model: DecisionTreeModel,
+      data: RDD[LabeledPoint],
+      threshold: Double = 0.5): Double = {
+    def predictedValue(features: Vector): Double = {
+      if (model.predict(features) < threshold) 0.0 else 1.0
+    }
+    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+    val count = data.count()
+    correctCount.toDouble / count
+  }
+
+  /**
+   * Calculates the mean squared error for regression.
+   */
+  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
+    data.map { y =>
+      val err = tree.predict(y.features) - y.label
+      err * err
+    }.mean()
+  }
+}
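The dense input format named above is also line-oriented: the label, a comma, then space-separated feature values. A small illustrative parse (the `DenseLineSketch` name is ours; `MLUtils.loadLabeledData` performs the RDD-based equivalent):

~~~
object DenseLineSketch {
  // Parses "1.0,0.5 2.0 3.5" into (label, features).
  def parse(line: String): (Double, Array[Double]) = {
    val Array(label, features) = line.split(',')
    (label.toDouble, features.trim.split(' ').map(_.toDouble))
  }

  def main(args: Array[String]) {
    val (label, features) = parse("1.0,0.5 2.0 3.5")
    println(s"label = $label, features = ${features.mkString(" ")}")
  }
}
~~~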
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
new file mode 100644
index 0000000000..f96bc1bf00
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors
+
+/**
+ * An example k-means app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DenseKMeans [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your
+ * app.
+ */
+object DenseKMeans {
+
+  object InitializationMode extends Enumeration {
+    type InitializationMode = Value
+    val Random, Parallel = Value
+  }
+
+  import InitializationMode._
+
+  case class Params(
+      input: String = null,
+      k: Int = -1,
+      numIterations: Int = 10,
+      initializationMode: InitializationMode = Parallel)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("DenseKMeans") {
+      head("DenseKMeans: an example k-means app for dense data.")
+      opt[Int]('k', "k")
+        .required()
+        .text("number of clusters, required")
+        .action((x, c) => c.copy(k = x))
+      opt[Int]("numIterations")
+        .text(s"number of iterations, default: ${defaultParams.numIterations}")
+        .action((x, c) => c.copy(numIterations = x))
+      opt[String]("initMode")
+        .text(s"initialization mode (${InitializationMode.values.mkString(",")}), " +
+          s"default: ${defaultParams.initializationMode}")
+        .action((x, c) => c.copy(initializationMode = InitializationMode.withName(x)))
+      arg[String]("<input>")
+        .text("input paths to examples")
+        .required()
+        .action((x, c) => c.copy(input = x))
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"DenseKMeans with $params")
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    val examples = sc.textFile(params.input).map { line =>
+      Vectors.dense(line.split(' ').map(_.toDouble))
+    }.cache()
+
+    val numExamples = examples.count()
+
+    println(s"numExamples = $numExamples.")
+
+    val initMode = params.initializationMode match {
+      case Random => KMeans.RANDOM
+      case Parallel => KMeans.K_MEANS_PARALLEL
+    }
+
+    val model = new KMeans()
+      .setInitializationMode(initMode)
+      .setK(params.k)
+      .setMaxIterations(params.numIterations)
+      .run(examples)
+
+    val cost = model.computeCost(examples)
+
+    println(s"Total cost = $cost.")
+
+    sc.stop()
+  }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
new file mode 100644
index 0000000000..1723ca6931
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.util.{MulticlassLabelParser, MLUtils}
+import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1Updater}
+
+/**
+ * An example app for linear regression. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.LinearRegression
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your
+ * app.
+ */
+object LinearRegression extends App {
+
+  object RegType extends Enumeration {
+    type RegType = Value
+    val NONE, L1, L2 = Value
+  }
+
+  import RegType._
+
+  case class Params(
+      input: String = null,
+      numIterations: Int = 100,
+      stepSize: Double = 1.0,
+      regType: RegType = L2,
+      regParam: Double = 0.1)
+
+  val defaultParams = Params()
+
+  val parser = new OptionParser[Params]("LinearRegression") {
+    head("LinearRegression: an example app for linear regression.")
+    opt[Int]("numIterations")
+      .text("number of iterations")
+      .action((x, c) => c.copy(numIterations = x))
+    opt[Double]("stepSize")
+      .text(s"initial step size, default: ${defaultParams.stepSize}")
+      .action((x, c) => c.copy(stepSize = x))
+    opt[String]("regType")
+      .text(s"regularization type (${RegType.values.mkString(",")}), " +
+        s"default: ${defaultParams.regType}")
+      .action((x, c) => c.copy(regType = RegType.withName(x)))
+    opt[Double]("regParam")
+      .text(s"regularization parameter, default: ${defaultParams.regParam}")
+      .action((x, c) => c.copy(regParam = x))
+    arg[String]("<input>")
+      .required()
+      .text("input paths to labeled examples in LIBSVM format")
+      .action((x, c) => c.copy(input = x))
+  }
+
+  parser.parse(args, defaultParams).map { params =>
+    run(params)
+  } getOrElse {
+    sys.exit(1)
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"LinearRegression with $params")
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    val examples = MLUtils.loadLibSVMData(sc, params.input, MulticlassLabelParser).cache()
+
+    val splits = examples.randomSplit(Array(0.8, 0.2))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
+
+    val numTraining = training.count()
+    val numTest = test.count()
+    println(s"Training: $numTraining, test: $numTest.")
+
+    examples.unpersist(blocking = false)
+
+    val updater = params.regType match {
+      case NONE => new SimpleUpdater()
+      case L1 => new L1Updater()
+      case L2 => new SquaredL2Updater()
+    }
+
+    val algorithm = new LinearRegressionWithSGD()
+    algorithm.optimizer
+      .setNumIterations(params.numIterations)
+      .setStepSize(params.stepSize)
+      .setUpdater(updater)
+      .setRegParam(params.regParam)
+
+    val model = algorithm.run(training)
+
+    val prediction = model.predict(test.map(_.features))
+    val predictionAndLabel = prediction.zip(test.map(_.label))
+
+    val loss = predictionAndLabel.map { case (p, l) =>
+      val err = p - l
+      err * err
+    }.reduce(_ + _)
+    val rmse = math.sqrt(loss / numTest)
+
+    println(s"Test RMSE = $rmse.")
+
+    sc.stop()
+  }
+}
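Note that `LinearRegression` is the one new app that still runs from an `extends App` constructor body, although commit be86069 above switched the others to explicit `main` methods; `scala.App`'s delayed initialization is known to interact badly with launchers such as `spark-submit` that invoke `main` reflectively. A sketch of the equivalent `main`-based skeleton the other examples use (the `MainVsApp` name is hypothetical):

~~~
object MainVsApp {
  case class Params(numIterations: Int = 100)

  // Explicit main instead of `extends App`: all fields are initialized
  // before main runs, so reflective launchers see a fully-built object.
  def main(args: Array[String]) {
    val parser = new scopt.OptionParser[Params]("MainVsApp") {
      opt[Int]("numIterations").action((x, c) => c.copy(numIterations = x))
    }
    parser.parse(args, Params()).map(run).getOrElse(sys.exit(1))
  }

  def run(params: Params) {
    println(s"running with $params")
  }
}
~~~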
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
new file mode 100644
index 0000000000..703f02255b
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import com.esotericsoftware.kryo.Kryo
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
+
+/**
+ * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
+ */
+object MovieLensALS {
+
+  class ALSRegistrator extends KryoRegistrator {
+    override def registerClasses(kryo: Kryo) {
+      kryo.register(classOf[Rating])
+    }
+  }
+
+  case class Params(
+      input: String = null,
+      kryo: Boolean = false,
+      numIterations: Int = 20,
+      lambda: Double = 1.0,
+      rank: Int = 10)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("MovieLensALS") {
+      head("MovieLensALS: an example app for ALS on MovieLens data.")
+      opt[Int]("rank")
+        .text(s"rank, default: ${defaultParams.rank}")
+        .action((x, c) => c.copy(rank = x))
+      opt[Int]("numIterations")
+        .text(s"number of iterations, default: ${defaultParams.numIterations}")
+        .action((x, c) => c.copy(numIterations = x))
+      opt[Double]("lambda")
+        .text(s"lambda (smoothing constant), default: ${defaultParams.lambda}")
+        .action((x, c) => c.copy(lambda = x))
+      opt[Unit]("kryo")
+        .text("use Kryo serialization")
+        .action((_, c) => c.copy(kryo = true))
+      arg[String]("<input>")
+        .required()
+        .text("input paths to a MovieLens dataset of ratings")
+        .action((x, c) => c.copy(input = x))
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    } getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
+    if (params.kryo) {
+      conf.set("spark.serializer", classOf[KryoSerializer].getName)
+        .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+        .set("spark.kryoserializer.buffer.mb", "8")
+    }
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    val ratings = sc.textFile(params.input).map { line =>
+      val fields = line.split("::")
+      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
+    }.cache()
+
+    val numRatings = ratings.count()
+    val numUsers = ratings.map(_.user).distinct().count()
+    val numMovies = ratings.map(_.product).distinct().count()
+
+    println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+    val splits = ratings.randomSplit(Array(0.8, 0.2))
+    val training = splits(0).cache()
+    val test = splits(1).cache()
+
+    val numTraining = training.count()
+    val numTest = test.count()
+    println(s"Training: $numTraining, test: $numTest.")
+
+    ratings.unpersist(blocking = false)
+
+    val model = new ALS()
+      .setRank(params.rank)
+      .setIterations(params.numIterations)
+      .setLambda(params.lambda)
+      .run(training)
+
+    val rmse = computeRmse(model, test, numTest)
+
+    println(s"Test RMSE = $rmse.")
+
+    sc.stop()
+  }
+
+  /** Compute RMSE (Root Mean Squared Error). */
+  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], n: Long): Double = {
+    val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
+    val predictionsAndRatings = predictions.map(x => ((x.user, x.product), x.rating))
+      .join(data.map(x => ((x.user, x.product), x.rating)))
+      .values
+    math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).reduce(_ + _) / n)
+  }
+}
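Besides the bulk RMSE evaluation above, a trained `MatrixFactorizationModel` can score a single (user, product) pair with `predict`. A toy end-to-end sketch under assumed inputs (the in-memory ratings and hyperparameters are made up; a real run would load a MovieLens file as the app does):

~~~
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object ALSPredictSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setAppName("ALSPredictSketch"))
    // A tiny hypothetical ratings set: two users, two products.
    val ratings = sc.parallelize(Seq(
      Rating(1, 1, 5.0), Rating(1, 2, 1.0),
      Rating(2, 1, 4.0), Rating(2, 2, 1.0)))
    val model = ALS.train(ratings, /* rank */ 2, /* iterations */ 10, /* lambda */ 0.01)
    // Score a single (user, product) pair.
    println(s"user 2, product 2 -> ${model.predict(2, 2)}")
    sc.stop()
  }
}
~~~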
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
new file mode 100644
index 0000000000..25b6768b8d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.classification.NaiveBayes
+import org.apache.spark.mllib.util.{MLUtils, MulticlassLabelParser}
+
+/**
+ * An example naive Bayes app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.SparseNaiveBayes [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your
+ * app.
+ */
+object SparseNaiveBayes {
+
+  case class Params(
+      input: String = null,
+      minPartitions: Int = 0,
+      numFeatures: Int = -1,
+      lambda: Double = 1.0)
+
+  def main(args: Array[String]) {
+    val defaultParams = Params()
+
+    val parser = new OptionParser[Params]("SparseNaiveBayes") {
+      head("SparseNaiveBayes: an example naive Bayes app for LIBSVM data.")
+      opt[Int]("minPartitions")
+        .text("min number of partitions")
+        .action((x, c) => c.copy(minPartitions = x))
+      opt[Int]("numFeatures")
+        .text("number of features")
+        .action((x, c) => c.copy(numFeatures = x))
+      opt[Double]("lambda")
+        .text(s"lambda (smoothing constant), default: ${defaultParams.lambda}")
+        .action((x, c) => c.copy(lambda = x))
+      arg[String]("<input>")
+        .text("input paths to labeled examples in LIBSVM format")
+        .required()
+        .action((x, c) => c.copy(input = x))
+    }
+
+    parser.parse(args, defaultParams).map { params =>
+      run(params)
+    }.getOrElse {
+      sys.exit(1)
+    }
+  }
+
+  def run(params: Params) {
+    val conf = new SparkConf().setAppName(s"SparseNaiveBayes with $params")
+    val sc = new SparkContext(conf)
+
+    Logger.getRootLogger.setLevel(Level.WARN)
+
+    val minPartitions =
+      if (params.minPartitions > 0) params.minPartitions else sc.defaultMinPartitions
+
+    val examples = MLUtils.loadLibSVMData(sc, params.input, MulticlassLabelParser,
+      params.numFeatures, minPartitions)
+    // Cache examples because it will be used in both training and evaluation.
+    examples.cache()
+
+    val splits = examples.randomSplit(Array(0.8, 0.2))
+    val training = splits(0)
+    val test = splits(1)
+
+    val numTraining = training.count()
+    val numTest = test.count()
+
+    println(s"numTraining = $numTraining, numTest = $numTest.")
+
+    val model = new NaiveBayes().setLambda(params.lambda).run(training)
+
+    val prediction = model.predict(test.map(_.features))
+    val predictionAndLabel = prediction.zip(test.map(_.label))
+    val accuracy = predictionAndLabel.filter(x => x._1 == x._2).count().toDouble / numTest
+
+    println(s"Test accuracy = $accuracy.")
+
+    sc.stop()
+  }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
index 39e71cdab4..3cd9cb743e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala
@@ -35,20 +35,16 @@ import org.apache.spark.mllib.linalg.Vectors
  */
 object TallSkinnyPCA {
   def main(args: Array[String]) {
-    if (args.length != 2) {
-      System.err.println("Usage: TallSkinnyPCA <master> <file>")
+    if (args.length != 1) {
+      System.err.println("Usage: TallSkinnyPCA <input>")
       System.exit(1)
     }
 
-    val conf = new SparkConf()
-      .setMaster(args(0))
-      .setAppName("TallSkinnyPCA")
-      .setSparkHome(System.getenv("SPARK_HOME"))
-      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)
+    val conf = new SparkConf().setAppName("TallSkinnyPCA")
     val sc = new SparkContext(conf)
 
     // Load and parse the data file.
-    val rows = sc.textFile(args(1)).map { line =>
+    val rows = sc.textFile(args(0)).map { line =>
       val values = line.split(' ').map(_.toDouble)
       Vectors.dense(values)
     }
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
index 2b7de2acc6..4d66903186 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala
@@ -35,20 +35,16 @@ import org.apache.spark.mllib.linalg.Vectors
  */
 object TallSkinnySVD {
   def main(args: Array[String]) {
-    if (args.length != 2) {
-      System.err.println("Usage: TallSkinnySVD <master> <file>")
+    if (args.length != 1) {
+      System.err.println("Usage: TallSkinnySVD <input>")
       System.exit(1)
     }
 
-    val conf = new SparkConf()
-      .setMaster(args(0))
-      .setAppName("TallSkinnySVD")
-      .setSparkHome(System.getenv("SPARK_HOME"))
-      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)
+    val conf = new SparkConf().setAppName("TallSkinnySVD")
     val sc = new SparkContext(conf)
 
     // Load and parse the data file.
-    val rows = sc.textFile(args(1)).map { line =>
+    val rows = sc.textFile(args(0)).map { line =>
       val values = line.split(' ').map(_.toDouble)
       Vectors.dense(values)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 4f9eaacf67..780e8bae42 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
-import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.DataValidators
 import org.apache.spark.rdd.RDD
 
 /**
@@ -183,19 +182,4 @@ object LogisticRegressionWithSGD {
       numIterations: Int): LogisticRegressionModel = {
     train(input, numIterations, 1.0, 1.0)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 4) {
-      println("Usage: LogisticRegression <master> <input_dir> <step_size> " +
-        "<niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "LogisticRegression")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
-    println("Weights: " + model.weights)
-    println("Intercept: " + model.intercept)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 18658850a2..f6f62ce2de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 
 /**
@@ -158,23 +157,4 @@ object NaiveBayes {
   def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
     new NaiveBayes(lambda).run(input)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 2 && args.length != 3) {
-      println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "NaiveBayes")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = if (args.length == 2) {
-      NaiveBayes.train(data)
-    } else {
-      NaiveBayes.train(data, args(2).toDouble)
-    }
-
-    println("Pi\n: " + model.pi)
-    println("Theta:\n" + model.theta)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 956654b1fe..81b126717e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.mllib.classification
 
-import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.mllib.util.DataValidators
 import org.apache.spark.rdd.RDD
 
 /**
@@ -183,19 +182,4 @@ object SVMWithSGD {
   def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
     train(input, numIterations, 1.0, 1.0, 1.0)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: SVM <master> <input_dir> <step_size> <regularization_parameter> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "SVM")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
-
-    println("Weights: " + model.weights)
-    println("Intercept: " + model.intercept)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index dee9ef07e4..a64c5d44be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -21,8 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils
@@ -396,28 +395,6 @@ object KMeans {
       v2: BreezeVectorWithNorm): Double = {
     MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
-
-  @Experimental
-  def main(args: Array[String]) {
-    if (args.length < 4) {
-      println("Usage: KMeans <master> <input_file> <k> <max_iterations> [<runs>]")
-      System.exit(1)
-    }
-    val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
-    val runs = if (args.length >= 5) args(4).toInt else 1
-    val sc = new SparkContext(master, "KMeans")
-    val data = sc.textFile(inputFile)
-      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
-      .cache()
-    val model = KMeans.train(data, k, iters, runs)
-    val cost = model.computeCost(data)
-    println("Cluster centers:")
-    for (c <- model.clusterCenters) {
-      println("  " + c)
-    }
-    println("Cost: " + cost)
-    System.exit(0)
-  }
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 60fb73f2b5..2a77e1a9ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -23,15 +23,13 @@ import scala.util.Random
 import scala.util.Sorting
 import scala.util.hashing.byteswap32
 
-import com.esotericsoftware.kryo.Kryo
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf}
+import org.apache.spark.{Logging, HashPartitioner, Partitioner}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
 
@@ -707,45 +705,4 @@ object ALS {
     : MatrixFactorizationModel = {
     trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
   }
-
-  private class ALSRegistrator extends KryoRegistrator {
-    override def registerClasses(kryo: Kryo) {
-      kryo.register(classOf[Rating])
-    }
-  }
-
-  def main(args: Array[String]) {
-    if (args.length < 5 || args.length > 9) {
-      println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> " +
-        "[<lambda>] [<implicitPrefs>] [<alpha>] [<blocks>]")
-      System.exit(1)
-    }
-    val (master, ratingsFile, rank, iters, outputDir) =
-      (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
-    val lambda = if (args.length >= 6) args(5).toDouble else 0.01
-    val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false
-    val alpha = if (args.length >= 8) args(7).toDouble else 1
-    val blocks = if (args.length == 9) args(8).toInt else -1
-    val conf = new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
-      .set("spark.kryo.referenceTracking", "false")
-      .set("spark.kryoserializer.buffer.mb", "8")
-      .set("spark.locality.wait", "10000")
-    val sc = new SparkContext(master, "ALS", conf)
-
-    val ratings = sc.textFile(ratingsFile).map { line =>
-      val fields = line.split(',')
-      Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble)
-    }
-    val model = new ALS(rank = rank, iterations = iters, lambda = lambda,
-      numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings)
-
-    model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
-      .saveAsTextFile(outputDir + "/userFeatures")
-    model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") }
-      .saveAsTextFile(outputDir + "/productFeatures")
-    println("Final user/product features written to " + outputDir)
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 5f0812fd2e..0e6fb1b1ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 
 /**
@@ -173,19 +171,4 @@ object LassoWithSGD {
       numIterations: Int): LassoModel = {
     train(input, numIterations, 1.0, 1.0, 1.0)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: Lasso <master> <input_dir> <step_size> <regularization_parameter> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "Lasso")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
-
-    println("Weights: " + model.weights)
-    println("Intercept: " + model.intercept)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 228fa8db3e..1532ff90d8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -17,11 +17,9 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
 
 /**
  * Regression model trained using LinearRegression.
@@ -156,18 +154,4 @@ object LinearRegressionWithSGD {
       numIterations: Int): LinearRegressionModel = {
     train(input, numIterations, 1.0, 1.0)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: LinearRegression <master> <input_dir> <step_size> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "LinearRegression")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
-    println("Weights: " + model.weights)
-    println("Intercept: " + model.intercept)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index e702027c7c..5f7e25a9b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.mllib.regression
 
-import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.mllib.linalg.Vector
 
 /**
@@ -170,21 +168,4 @@ object RidgeRegressionWithSGD {
       numIterations: Int): RidgeRegressionModel = {
     train(input, numIterations, 1.0, 1.0, 1.0)
   }
-
-  def main(args: Array[String]) {
-    if (args.length != 5) {
-      println("Usage: RidgeRegression <master> <input_dir> <step_size> " +
-        "<regularization_parameter> <niters>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "RidgeRegression")
-    val data = MLUtils.loadLabeledData(sc, args(1))
-    val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
-      args(3).toDouble)
-
-    println("Weights: " + model.weights)
-    println("Intercept: " + model.intercept)
-
-    sc.stop()
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f68076f426..59ed01debf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -18,18 +18,16 @@ package org.apache.spark.mllib.tree
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.SparkContext._
+import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
  * :: Experimental ::
@@ -1028,129 +1026,4 @@ object DecisionTree extends Serializable with Logging {
       throw new UnsupportedOperationException("approximate histogram not supported yet.")
     }
   }
-
-  private val usage = """
-    Usage: DecisionTreeRunner <master> [slices] --algo <Classification,Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini,Entropy,Variance>] [--maxBins num]
-  """
-
-  def main(args: Array[String]) {
-
-    if (args.length < 2) {
-      System.err.println(usage)
-      System.exit(1)
-    }
-
-    val sc = new SparkContext(args(0), "DecisionTree")
-
-    val argList = args.toList.drop(1)
-    type OptionMap = Map[Symbol, Any]
-
-    def nextOption(map : OptionMap, list: List[String]): OptionMap = {
-      list match {
-        case Nil => map
-        case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
-        case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
-        case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
-        case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
-        case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
-          , tail)
-        case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
-          tail)
-        case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
-        case option :: tail => logError("Unknown option " + option)
-          sys.exit(1)
-      }
-    }
-    val options = nextOption(Map(), argList)
-    logDebug(options.toString())
-
-    // Load training data.
-    val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
-
-    // Identify the type of algorithm.
-    val algoStr = options.get('algo).get.toString
-    val algo = algoStr match {
-      case "Classification" => Classification
-      case "Regression" => Regression
-    }
-
-    // Identify the type of impurity.
-    val impurityStr = options.getOrElse('impurity,
-      if (algo == Classification) "Gini" else "Variance").toString
-    val impurity = impurityStr match {
-      case "Gini" => Gini
-      case "Entropy" => Entropy
-      case "Variance" => Variance
-    }
-
-    val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
-    val maxBins = options.getOrElse('maxBins, "100").toString.toInt
-
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
-    val model = DecisionTree.train(trainData, strategy)
-
-    // Load test data.
-    val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
-
-    // Measure algorithm accuracy
-    if (algo == Classification) {
-      val accuracy = accuracyScore(model, testData)
-      logDebug("accuracy = " + accuracy)
-    }
-
-    if (algo == Regression) {
-      val mse = meanSquaredError(model, testData)
-      logDebug("mean square error = " + mse)
-    }
-
-    sc.stop()
-  }
-
-  /**
-   * Load labeled data from a file. The data format used here is
-   * <L>, <f1> <f2> ...,
-   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.
-   *
-   * @param sc SparkContext
-   * @param dir Directory to the input data files.
-   * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
-   *         the label, and the second element represents the feature values (an array of Double).
-   */
-  private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
-    sc.textFile(dir).map { line =>
-      val parts = line.trim().split(",")
-      val label = parts(0).toDouble
-      val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
-      LabeledPoint(label, features)
-    }
-  }
-
-  // TODO: Port this method to a generic metrics package.
-  /**
-   * Calculates the classifier accuracy.
-   */
-  private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector) = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
-    val count = data.count()
-    logDebug("correct prediction count = " + correctCount)
-    logDebug("data count = " + count)
-    correctCount.toDouble / count
-  }
-
-  // TODO: Port this method to a generic metrics package
-  /**
-   * Calculates the mean squared error for regression.
-   */
-  private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
-    data.map { y =>
-      val err = tree.predict(y.features) - y.label
-      err * err
-    }.mean()
-  }
 }
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 9674fe9383..ea91ec7021 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -426,7 +426,8 @@ object SparkBuild extends Build {
         exclude("io.netty", "netty")
         exclude("jline","jline")
         exclude("org.apache.cassandra.deps", "avro")
-        excludeAll(excludeSLF4J)
+        excludeAll(excludeSLF4J),
+      "com.github.scopt" %% "scopt" % "3.2.0"
     )
   ) ++ assemblySettings ++ extraAssemblySettings
-- 
cgit v1.2.3