aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorBasin <jpsachilles@gmail.com>2015-01-21 23:06:34 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-21 23:06:34 -0800
commitfcb3e1862ffe784f39bde467e8d24c1b7ed3afbb (patch)
tree5355559f7e602aa78512fff8a93816e02e2d9bf1 /mllib
parentca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d (diff)
downloadspark-fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb.tar.gz
spark-fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb.tar.bz2
spark-fcb3e1862ffe784f39bde467e8d24c1b7ed3afbb.zip
[SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression
JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317 When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification). I overload the method BoostingStragety.defaultParams(). Author: Basin <jpsachilles@gmail.com> Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits: 87bab1c [Basin] Docs and Code documentations updated. 3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo). 7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead. d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo 65f96ce [Basin] mllib-ensembles doc modified. e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration. 68cf544 [Basin] mllib-ensembles doc modified. a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala25
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala14
2 files changed, 28 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index cf51d041c6..ed8e6a796f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -70,20 +70,29 @@ object BoostingStrategy {
/**
* Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported: "Classification" or "Regression"
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ defaultParams(Algo.fromString(algo))
+ }
+
+ /**
+ * Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
* [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy(algo)
- treeStrategy.maxDepth = 3
+ def defaultParams(algo: Algo): BoostingStrategy = {
+ val treeStragtegy = Strategy.defaultStategy(algo)
+ treeStragtegy.maxDepth = 3
algo match {
- case "Classification" =>
- treeStrategy.numClasses = 2
- new BoostingStrategy(treeStrategy, LogLoss)
- case "Regression" =>
- new BoostingStrategy(treeStrategy, SquaredError)
+ case Algo.Classification =>
+ treeStragtegy.numClasses = 2
+ new BoostingStrategy(treeStragtegy, LogLoss)
+ case Algo.Regression =>
+ new BoostingStrategy(treeStragtegy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d5cd89ab94..972959885f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -173,11 +173,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
- def defaultStrategy(algo: String): Strategy = algo match {
- case "Classification" =>
+ def defaultStrategy(algo: String): Strategy = {
+ defaultStategy(Algo.fromString(algo))
+ }
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo Algo.Classification or Algo.Regression
+ */
+ def defaultStategy(algo: Algo): Strategy = algo match {
+ case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
- case "Regression" =>
+ case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}