aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> 2014-10-08 14:23:21 -0700
committer: Xiangrui Meng <meng@databricks.com> 2014-10-08 14:23:21 -0700
commit: b92bd5a2f29f7a9ce270540b6a828fa7ff205cbe (patch)
tree: 60d6d39cda43386487017bb9fe33d53e926975c4 /examples/src
parent: bc4418727b40c9b6ba5194ead6e2698539272280 (diff)
download: spark-b92bd5a2f29f7a9ce270540b6a828fa7ff205cbe.tar.gz
spark-b92bd5a2f29f7a9ce270540b6a828fa7ff205cbe.tar.bz2
spark-b92bd5a2f29f7a9ce270540b6a828fa7ff205cbe.zip
[SPARK-3841] [mllib] Pretty-print params for ML examples
Provide a parent class for the Params case classes used in many MLlib examples, where the parent class pretty-prints the case class fields:

    Param1Name    Param1Value
    Param2Name    Param2Value
    ...

Using this class will make it easier to print test settings to logs.

Also, updated DecisionTreeRunner to print a little more info.

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #2700 from jkbradley/dtrunner-update and squashes the following commits:

cff873f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
7a08ae4 [Joseph K. Bradley] code review comment updates
b4d2043 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
d8228a7 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
0fc9c64 [Joseph K. Bradley] Added abstract TestParams class for mllib example parameters
12b7798 [Joseph K. Bradley] Added abstract class TestParams for pretty-printing Params values
5f84f03 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
f7441b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dtrunner-update
19eb6fc [Joseph K. Bradley] Updated DecisionTreeRunner to print training time.
Diffstat (limited to 'examples/src')
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala | 53
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala | 2
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala | 1
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala | 1
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 15
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala | 2
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala | 2
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala | 2
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala | 1
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala | 1
-rw-r--r-- examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala | 2
11 files changed, 75 insertions, 7 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
new file mode 100644
index 0000000000..ae6057758d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scala.reflect.runtime.universe._
+
+/**
+ * Abstract base class for parameter case classes.
+ * This overrides the [[toString]] method to print all case class fields by name and value.
+ * @tparam T Concrete parameter class whose case class fields are looked up through its TypeTag.
+ */
+abstract class AbstractParams[T: TypeTag] {
+
+ private def tag: TypeTag[T] = typeTag[T] // TypeTag supplied by the context bound on T
+
+ /**
+ * Finds the case class fields declared on T and reads their values from this instance.
+ * Output is JSON-style: entries are joined by a comma and newline and wrapped in braces:
+ * {
+ * [field name]:\t[field value],\n
+ * [field name]:\t[field value]\n
+ * }
+ */
+ override def toString: String = {
+ val tpe = tag.tpe
+ val allAccessors = tpe.declarations.collect { // keep only the case accessor methods of T
+ case m: MethodSymbol if m.isCaseAccessor => m
+ }
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ val instanceMirror = mirror.reflect(this) // NOTE(review): assumes this is an instance of T
+ allAccessors.map { f =>
+ val paramName = f.name.toString
+ val fieldMirror = instanceMirror.reflectField(f)
+ val paramValue = fieldMirror.get
+ s" $paramName:\t$paramValue"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
index a6f78d2441..1edd2432a0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -55,7 +55,7 @@ object BinaryClassification {
stepSize: Double = 1.0,
algorithm: Algorithm = LR,
regType: RegType = L2,
- regParam: Double = 0.1)
+ regParam: Double = 0.1) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
index d6b2fe430e..e49129c4e7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
@@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object Correlations {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
index 6a3b0241ce..cb1abbd18f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
@@ -43,6 +43,7 @@ import org.apache.spark.{SparkConf, SparkContext}
*/
object CosineSimilarity {
case class Params(inputFile: String = null, threshold: Double = 0.1)
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 4adc91d2fb..837d059147 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -62,7 +62,7 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2)
+ fracTest: Double = 0.2) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -138,9 +138,11 @@ object DecisionTreeRunner {
def run(params: Params) {
- val conf = new SparkConf().setAppName("DecisionTreeRunner")
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
val sc = new SparkContext(conf)
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
// Load training data and cache it.
val origExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
@@ -235,7 +237,10 @@ object DecisionTreeRunner {
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain)
if (params.numTrees == 1) {
+ val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.numNodes < 20) {
println(model.toDebugString) // Print full model.
} else {
@@ -259,8 +264,11 @@ object DecisionTreeRunner {
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
@@ -275,8 +283,11 @@ object DecisionTreeRunner {
println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
index 89dfa26c22..11e35598ba 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -44,7 +44,7 @@ object DenseKMeans {
input: String = null,
k: Int = -1,
numIterations: Int = 10,
- initializationMode: InitializationMode = Parallel)
+ initializationMode: InitializationMode = Parallel) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
index 05b7d66f8d..e1f9622350 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -47,7 +47,7 @@ object LinearRegression extends App {
numIterations: Int = 100,
stepSize: Double = 1.0,
regType: RegType = L2,
- regParam: Double = 0.1)
+ regParam: Double = 0.1) extends AbstractParams[Params]
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 98aaedb9d7..fc6678013b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -55,7 +55,7 @@ object MovieLensALS {
rank: Int = 10,
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
- implicitPrefs: Boolean = false)
+ implicitPrefs: Boolean = false) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
index 4532512c01..6e4e2d07f2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
@@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object MultivariateSummarizer {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
index f01b8266e3..663c12734a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._
object SampledRDDs {
case class Params(input: String = "data/mllib/sample_binary_classification_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
index 952fa2a510..f1ff4e6911 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -37,7 +37,7 @@ object SparseNaiveBayes {
input: String = null,
minPartitions: Int = 0,
numFeatures: Int = -1,
- lambda: Double = 1.0)
+ lambda: Double = 1.0) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()