diff options
Diffstat (limited to 'examples/src/main')
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala | 322 |
1 files changed, 322 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala new file mode 100644 index 0000000000..d4cc8dede0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.ml.tree.DecisionTreeModel +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{SQLContext, DataFrame} + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.DecisionTreeExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DecisionTreeExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DecisionTreeExample") { + head("DecisionTreeExample: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("<dataFormat>") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("<input>") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + /** Load a dataset from the given path, using the given format */ + private[ml] def loadData( + sc: SparkContext, + path: String, + format: String, + expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + format match { + case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "libsvm" => expectedNumFeatures match { + case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) + case None => MLUtils.loadLibSVMFile(sc, path) + } + case _ => throw new IllegalArgumentException(s"Bad data format: $format") + } + } + + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset) + */ + private[ml] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: String, + fracTest: Double): (DataFrame, DataFrame) = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Load training data + val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + + // Load or create test set + val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + // Load testInput. + val numFeatures = origExamples.take(1)(0).features.size + val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures)) + Array(origExamples, origTestExamples) + } else { + // Split input into training, test. + origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) + } + + // For classification, convert labels to Strings since we will index them later with + // StringIndexer. + def labelsToStrings(data: DataFrame): DataFrame = { + algo.toLowerCase match { + case "classification" => + data.withColumn("labelString", data("label").cast(StringType)) + case "regression" => + data + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + } + val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + + (dataframes(0), dataframes(1)) + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"DecisionTreeExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = + loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) + + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer().setInputCol("features") + .setOutputCol("indexedFeatures").setMaxCategories(10) + stages += featuresIndexer + // (3) Learn DecisionTree + val dt = algo match { + case "classification" => + new DecisionTreeClassifier().setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case "regression" => + new DecisionTreeRegressor().setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Decision Tree from the fitted PipelineModel + val treeModel: DecisionTreeModel = algo match { + case "classification" => + pipelineModel.getModel[DecisionTreeClassificationModel]( + dt.asInstanceOf[DecisionTreeClassifier]) + case "regression" => + pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } + + // Predict on training + val trainingFullPredictions = pipelineModel.transform(training).cache() + val trainingPredictions = trainingFullPredictions.select("prediction") + .map(_.getDouble(0)) + val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) + // Predict on test data + val testFullPredictions = pipelineModel.transform(test).cache() + val testPredictions = testFullPredictions.select("prediction") + .map(_.getDouble(0)) + val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) + + // For classification, print number of classes for reference. + if (algo == "classification") { + val numClasses = + MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "DecisionTreeExample had unknown failure when indexing labels for classification.") + } + println(s"numClasses = $numClasses.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + val trainingAccuracy = + new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision + println(s"Train accuracy = $trainingAccuracy") + val testAccuracy = + new MulticlassMetrics(testPredictions.zip(testLabels)).precision + println(s"Test accuracy = $testAccuracy") + case "regression" => + val trainingRMSE = + new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError + println(s"Training root mean squared error (RMSE) = $trainingRMSE") + val testRMSE = + new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError + println(s"Test root mean squared error (RMSE) = $testRMSE") + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} |