author    Joseph K. Bradley <joseph@databricks.com>  2014-12-04 09:57:50 +0800
committer Xiangrui Meng <meng@databricks.com>  2014-12-04 09:57:50 +0800
commit    657a88835d8bf22488b53d50f75281d7dc32442e (patch)
tree      3e72a27719b8c03cf2da2d2d7280e0b808a68043 /mllib
parent    27ab0b8a03b711e8d86b6167df833f012205ccc7 (diff)
[SPARK-4580] [SPARK-4610] [mllib] [docs] Documentation for tree ensembles + DecisionTree API fix
Major changes:
* Added programming guide sections for tree ensembles
* Added examples for tree ensembles
* Updated DecisionTree programming guide with more info on parameters
* **API change**: Standardized the tree parameter for the number of classes (for classification)

Minor changes:
* Updated decision tree documentation
* Updated existing tree and tree ensemble examples
  * Use train/test split, and compute test error instead of training error.
  * Fixed decision_tree_runner.py to actually use the number of classes it computes from data. (small bug fix)

Note: I know this is a lot of lines, but most is covered by:
* Programming guide sections for gradient boosting and random forests. (The changes are probably best viewed by generating the docs locally.)
* New examples (which were copied from the programming guide)
* The "numClasses" renaming

I have run all examples and relevant unit tests.

CC: mengxr manishamde codedeft

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #3461 from jkbradley/ensemble-docs and squashes the following commits:

70a75f3 [Joseph K. Bradley] updated forest vs boosting comparison
d1de753 [Joseph K. Bradley] Added note about toString and toDebugString for DecisionTree to migration guide
8e87f8f [Joseph K. Bradley] Combined GBT and RandomForest guides into one ensembles guide
6fab846 [Joseph K. Bradley] small fixes based on review
b9f8576 [Joseph K. Bradley] updated decision tree doc
375204c [Joseph K. Bradley] fixed python style
2b60b6e [Joseph K. Bradley] merged Java RandomForest examples into 1 file. added header. Fixed small bug in same example in the programming guide.
706d332 [Joseph K. Bradley] updated python DT runner to print full model if it is small
c76c823 [Joseph K. Bradley] added migration guide for mllib
abe5ed7 [Joseph K. Bradley] added examples for random forest in Java and Python to examples folder
07fc11d [Joseph K. Bradley] Renamed numClassesForClassification to numClasses everywhere in trees and ensembles. This is a breaking API change, but it was necessary to correct an API inconsistency in Spark 1.1 (where Python DecisionTree used numClasses but Scala used numClassesForClassification).
cdfdfbc [Joseph K. Bradley] added examples for GBT
6372a2b [Joseph K. Bradley] updated decision tree examples to use random split. tested all of them.
ad3e695 [Joseph K. Bradley] added gbt and random forest to programming guide. still need to update their examples
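Editor's note: the **API change** above renames numClassesForClassification to numClasses across trees and ensembles, and the updated examples switch to a random train/test split with test error. A minimal sketch of the new call shape (not part of this commit): it assumes a SparkContext sc and a LIBSVM-format dataset path; the trainClassifier signature itself matches the diff below.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.util.MLUtils

    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Random train/test split; the updated examples compute test error, not training error.
    val Array(training, test) = data.randomSplit(Array(0.7, 0.3))

    val model = DecisionTree.trainClassifier(training,
      numClasses = 2,                        // renamed from numClassesForClassification
      categoricalFeaturesInfo = Map[Int, Int](),
      impurity = "gini", maxDepth = 5, maxBins = 32)

    // Test error: fraction of test points the model misclassifies.
    val testErr = test.map { p =>
      if (model.predict(p.features) == p.label) 0.0 else 1.0
    }.mean()
    println(s"Test error = $testErr")
    println(model.toDebugString)             // full model; toString gives a short summary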
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala           |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                   | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala                   | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala         | 26
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala      |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala              | 46
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala      |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala              | 14
9 files changed, 71 insertions(+), 71 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9f20cd5d00..c4e5fd8e46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -477,7 +477,7 @@ class PythonMLLibAPI extends Serializable {
algo = algo,
impurity = impurity,
maxDepth = maxDepth,
- numClassesForClassification = numClasses,
+ numClasses = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
@@ -513,7 +513,7 @@ class PythonMLLibAPI extends Serializable {
algo = algo,
impurity = impurity,
maxDepth = maxDepth,
- numClassesForClassification = numClasses,
+ numClasses = numClasses,
maxBins = maxBins,
categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3d91867c89..73e7e32c6d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -136,7 +136,7 @@ object DecisionTree extends Serializable with Logging {
* @param impurity impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification number of classes for classification. Default value of 2.
+ * @param numClasses number of classes for classification. Default value of 2.
* @return DecisionTreeModel that can be used for prediction
*/
def train(
@@ -144,8 +144,8 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int,
- numClassesForClassification: Int): DecisionTreeModel = {
- val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
+ numClasses: Int): DecisionTreeModel = {
+ val strategy = new Strategy(algo, impurity, maxDepth, numClasses)
new DecisionTree(strategy).run(input)
}
@@ -164,7 +164,7 @@ object DecisionTree extends Serializable with Logging {
* @param impurity criterion used for information gain calculation
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification number of classes for classification. Default value of 2.
+ * @param numClasses number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo Map storing arity of categorical features.
@@ -177,11 +177,11 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int,
- numClassesForClassification: Int,
+ numClasses: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
- val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
new DecisionTree(strategy).run(input)
}
@@ -191,7 +191,7 @@ object DecisionTree extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
- * @param numClassesForClassification number of classes for classification.
+ * @param numClasses number of classes for classification.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
@@ -206,13 +206,13 @@ object DecisionTree extends Serializable with Logging {
*/
def trainClassifier(
input: RDD[LabeledPoint],
- numClassesForClassification: Int,
+ numClasses: Int,
categoricalFeaturesInfo: Map[Int, Int],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
val impurityType = Impurities.fromString(impurity)
- train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+ train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
categoricalFeaturesInfo)
}
@@ -221,12 +221,12 @@ object DecisionTree extends Serializable with Logging {
*/
def trainClassifier(
input: JavaRDD[LabeledPoint],
- numClassesForClassification: Int,
+ numClasses: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
impurity: String,
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
- trainClassifier(input.rdd, numClassesForClassification,
+ trainClassifier(input.rdd, numClasses,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
impurity, maxDepth, maxBins)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 482d339551..e9304b5e5c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -56,7 +56,7 @@ import org.apache.spark.util.Utils
* etc.
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * Supported: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
@@ -269,7 +269,7 @@ object RandomForest extends Serializable with Logging {
* @param strategy Parameters for training each tree in the forest.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * Supported: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt".
@@ -293,13 +293,13 @@ object RandomForest extends Serializable with Logging {
*
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* Labels should take values {0, 1, ..., numClasses-1}.
- * @param numClassesForClassification number of classes for classification.
+ * @param numClasses number of classes for classification.
* @param categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * Supported: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt".
@@ -315,7 +315,7 @@ object RandomForest extends Serializable with Logging {
*/
def trainClassifier(
input: RDD[LabeledPoint],
- numClassesForClassification: Int,
+ numClasses: Int,
categoricalFeaturesInfo: Map[Int, Int],
numTrees: Int,
featureSubsetStrategy: String,
@@ -325,7 +325,7 @@ object RandomForest extends Serializable with Logging {
seed: Int = Utils.random.nextInt()): RandomForestModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
- numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
+ numClasses, maxBins, Sort, categoricalFeaturesInfo)
trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
}
@@ -334,7 +334,7 @@ object RandomForest extends Serializable with Logging {
*/
def trainClassifier(
input: JavaRDD[LabeledPoint],
- numClassesForClassification: Int,
+ numClasses: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
numTrees: Int,
featureSubsetStrategy: String,
@@ -342,7 +342,7 @@ object RandomForest extends Serializable with Logging {
maxDepth: Int,
maxBins: Int,
seed: Int): RandomForestModel = {
- trainClassifier(input.rdd, numClassesForClassification,
+ trainClassifier(input.rdd, numClasses,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
}
@@ -355,7 +355,7 @@ object RandomForest extends Serializable with Logging {
* @param strategy Parameters for training each tree in the forest.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * Supported: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "onethird".
@@ -384,7 +384,7 @@ object RandomForest extends Serializable with Logging {
* with k categories indexed from 0: {0, 1, ..., k-1}.
* @param numTrees Number of trees in the random forest.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
- * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * Supported: "auto", "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "onethird".
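Editor's note: the Javadoc in this file pins down how featureSubsetStrategy = "auto" resolves: "all" for a single tree, "sqrt" for a classification forest, "onethird" for a regression forest. A short sketch of the renamed classifier entry point, reusing the training RDD from the DecisionTree sketch above; the hyperparameter values are illustrative only.

    import org.apache.spark.mllib.tree.RandomForest

    val forest = RandomForest.trainClassifier(training,
      numClasses = 2,                        // same rename as in DecisionTree
      categoricalFeaturesInfo = Map[Int, Int](),
      numTrees = 10,                         // > 1, so bootstrapping is used
      featureSubsetStrategy = "auto",        // resolves to "sqrt" for classification
      impurity = "gini", maxDepth = 4, maxBins = 32, seed = 12345)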
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index e703adbdbf..cf51d041c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -51,7 +51,7 @@ case class BoostingStrategy(
private[tree] def assertValid(): Unit = {
treeStrategy.algo match {
case Classification =>
- require(treeStrategy.numClassesForClassification == 2,
+ require(treeStrategy.numClasses == 2,
"Only binary classification is supported for boosting.")
case Regression =>
// nothing
@@ -80,12 +80,12 @@ object BoostingStrategy {
treeStrategy.maxDepth = 3
algo match {
case "Classification" =>
- treeStrategy.numClassesForClassification = 2
+ treeStrategy.numClasses = 2
new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
+ throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
}
}
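Editor's note: assertValid above limits boosting to binary classification, and the default classification parameters in object BoostingStrategy set numClasses = 2 with LogLoss. A sketch of running gradient boosting with those defaults; the defaultParams factory name is not visible in this hunk and is assumed from the Spark 1.2 API, as is the training RDD.

    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    // Assumed factory: returns LogLoss plus a tree strategy with maxDepth = 3, numClasses = 2.
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 3               // keep the sketch fast
    boostingStrategy.treeStrategy.numClasses = 2     // anything else fails assertValid

    val gbtModel = GradientBoostedTrees.train(training, boostingStrategy)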
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d75f38433c..d5cd89ab94 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -37,7 +37,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification Number of classes for classification.
+ * @param numClasses Number of classes for classification.
* (Ignored for regression.)
* Default value is 2 (binary classification).
* @param maxBins Maximum number of bins used for discretizing continuous features and
@@ -73,7 +73,7 @@ class Strategy (
@BeanProperty var algo: Algo,
@BeanProperty var impurity: Impurity,
@BeanProperty var maxDepth: Int,
- @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var numClasses: Int = 2,
@BeanProperty var maxBins: Int = 32,
@BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
@BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@@ -86,7 +86,7 @@ class Strategy (
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
def isMulticlassClassification =
- algo == Classification && numClassesForClassification > 2
+ algo == Classification && numClasses > 2
def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
@@ -97,10 +97,10 @@ class Strategy (
algo: Algo,
impurity: Impurity,
maxDepth: Int,
- numClassesForClassification: Int,
+ numClasses: Int,
maxBins: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
- this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
+ this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
@@ -117,8 +117,8 @@ class Strategy (
*/
def setCategoricalFeaturesInfo(
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
- setCategoricalFeaturesInfo(
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ this.categoricalFeaturesInfo =
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap
}
/**
@@ -128,9 +128,9 @@ class Strategy (
private[tree] def assertValid(): Unit = {
algo match {
case Classification =>
- require(numClassesForClassification >= 2,
- s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
- s" but numClassesForClassification = $numClassesForClassification.")
+ require(numClasses >= 2,
+ s"DecisionTree Strategy for Classification must have numClasses >= 2," +
+ s" but numClasses = $numClasses.")
require(Set(Gini, Entropy).contains(impurity),
s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
s" Valid settings: Gini, Entropy")
@@ -160,7 +160,7 @@ class Strategy (
/** Returns a shallow copy of this instance. */
def copy: Strategy = {
- new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
}
@@ -176,9 +176,9 @@ object Strategy {
def defaultStrategy(algo: String): Strategy = algo match {
case "Classification" =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
- numClassesForClassification = 2)
+ numClasses = 2)
case "Regression" =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
- numClassesForClassification = 0)
+ numClasses = 0)
}
}
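Editor's note: Strategy is where the numClasses default of 2 and the Java-friendly setters live, so the rename ripples out from this class. A small sketch of building a multiclass Strategy directly and passing it to DecisionTree.train (the strategy-based overload exercised by the test suites below); values are illustrative.

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 4,
      numClasses = 3,                                // was numClassesForClassification
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) // features 0 and 1 have 3 categories each
    assert(strategy.isMulticlassClassification)      // true, since numClasses > 2
    val dtModel = DecisionTree.train(training, strategy)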
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 5bc0f2635c..951733fada 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -110,7 +110,7 @@ private[tree] object DecisionTreeMetadata extends Logging {
val numFeatures = input.take(1)(0).features.size
val numExamples = input.count()
val numClasses = strategy.algo match {
- case Classification => strategy.numClassesForClassification
+ case Classification => strategy.numClasses
case Regression => 0
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 972c905ec9..9347eaf922 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -57,7 +57,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
Classification,
Gini,
maxDepth = 2,
- numClassesForClassification = 2,
+ numClasses = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
@@ -81,7 +81,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
Classification,
Gini,
maxDepth = 2,
- numClassesForClassification = 2,
+ numClasses = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
@@ -177,7 +177,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
Classification,
Gini,
maxDepth = 2,
- numClassesForClassification = 100,
+ numClasses = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -271,7 +271,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
Classification,
Gini,
maxDepth = 2,
- numClassesForClassification = 100,
+ numClasses = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
// 2^(10-1) - 1 > 100, so categorical features will be ordered
@@ -295,7 +295,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val strategy = new Strategy(
Classification,
Gini,
- numClassesForClassification = 2,
+ numClasses = 2,
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -377,7 +377,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
- numClassesForClassification = 2, maxBins = 100)
+ numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -401,7 +401,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
- numClassesForClassification = 2, maxBins = 100)
+ numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -426,7 +426,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
- numClassesForClassification = 2, maxBins = 100)
+ numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -451,7 +451,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
- numClassesForClassification = 2, maxBins = 100)
+ numClasses = 2, maxBins = 100)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -485,7 +485,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
// Train a 1-node model
val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClassesForClassification = 2, maxBins = 100)
+ numClasses = 2, maxBins = 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
val rootNode1 = modelOneNode.topNode.deepCopy()
val rootNode2 = modelOneNode.topNode.deepCopy()
@@ -545,7 +545,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(strategy.isMulticlassClassification)
assert(metadata.isUnordered(featureIndex = 0))
@@ -568,7 +568,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 2)
+ numClasses = 2)
val model = DecisionTree.train(rdd, strategy)
DecisionTreeSuite.validateClassifier(model, arr, 1.0)
@@ -585,7 +585,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 2)
+ numClasses = 2)
val model = DecisionTree.train(rdd, strategy)
DecisionTreeSuite.validateClassifier(model, arr, 1.0)
@@ -600,7 +600,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, maxBins = maxBins,
+ numClasses = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -629,7 +629,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, maxBins = 100)
+ numClasses = 3, maxBins = 100)
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -650,7 +650,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
+ numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(metadata.isUnordered(featureIndex = 0))
@@ -671,7 +671,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, maxBins = 100,
+ numClasses = 3, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -692,7 +692,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, maxBins = 10,
+ numClasses = 3, maxBins = 10,
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
@@ -708,7 +708,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
- maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
+ maxDepth = 2, numClasses = 2, minInstancesPerNode = 2)
val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
@@ -737,7 +737,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
- numClassesForClassification = 2, minInstancesPerNode = 2)
+ numClasses = 2, minInstancesPerNode = 2)
val rootNode = DecisionTree.train(rdd, strategy).topNode
@@ -755,7 +755,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, minInfoGain = 1.0)
+ numClasses = 2, minInfoGain = 1.0)
val model = DecisionTree.train(input, strategy)
assert(model.topNode.isLeaf)
@@ -781,7 +781,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
@@ -824,7 +824,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index d4d54cf4c9..3aa97e5446 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -100,7 +100,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
+ numClasses = 2, categoricalFeaturesInfo = Map.empty,
subsamplingRate = subsamplingRate)
val boostingStrategy =
new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 90a8c2dfda..f7f0f20c6c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -57,7 +57,7 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
binaryClassificationTestWithContinuousFeatures(strategy)
}
@@ -65,7 +65,7 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
useNodeIdCache = true)
binaryClassificationTestWithContinuousFeatures(strategy)
}
@@ -93,7 +93,7 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Regression, impurity = Variance,
- maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ maxDepth = 2, maxBins = 10, numClasses = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo)
regressionTestWithContinuousFeatures(strategy)
}
@@ -102,7 +102,7 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Regression, impurity = Variance,
- maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ maxDepth = 2, maxBins = 10, numClasses = 2,
categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
regressionTestWithContinuousFeatures(strategy)
}
@@ -169,14 +169,14 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
test("Binary classification with continuous features: subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
test("Binary classification with continuous features and node Id cache: subsampling features") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
useNodeIdCache = true)
binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
}
@@ -191,7 +191,7 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)