From fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb Mon Sep 17 00:00:00 2001 From: Basin Date: Wed, 21 Jan 2015 23:06:34 -0800 Subject: [SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration. --- .../tree/configuration/BoostingStrategy.scala | 25 +++++++++++++++------- .../spark/mllib/tree/configuration/Strategy.scala | 14 +++++++++--- 2 files changed, 28 insertions(+), 11 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index cf51d041c6..ed8e6a796f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -68,6 +68,15 @@ case class BoostingStrategy( @Experimental object BoostingStrategy { + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: "Classification" or "Regression" + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + defaultParams(Algo.fromString(algo)) + } + /** * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: @@ -75,15 +84,15 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm */ - def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy(algo) - treeStrategy.maxDepth = 3 + def defaultParams(algo: Algo): BoostingStrategy = { + val treeStragtegy = Strategy.defaultStategy(algo) + treeStragtegy.maxDepth = 3 algo match { - case "Classification" => - treeStrategy.numClasses = 2 - new BoostingStrategy(treeStrategy, LogLoss) - case "Regression" => - new BoostingStrategy(treeStrategy, SquaredError) + case Algo.Classification => + treeStragtegy.numClasses = 2 + new BoostingStrategy(treeStragtegy, LogLoss) + case Algo.Regression => + new BoostingStrategy(treeStragtegy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d5cd89ab94..972959885f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -173,11 +173,19 @@ object Strategy { * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" */ - def defaultStrategy(algo: String): Strategy = algo match { - case "Classification" => + def defaultStrategy(algo: String): Strategy = { + defaultStategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + def defaultStategy(algo: Algo): Strategy = algo match { + case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, numClasses = 2) - case "Regression" => + case Algo.Regression => new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, numClasses = 0) } -- cgit v1.2.3