author     Xiangrui Meng <meng@databricks.com>    2014-09-08 18:59:57 -0700
committer  Xiangrui Meng <meng@databricks.com>    2014-09-08 18:59:57 -0700
commit     50a4fa774a0e8a17d7743b33ce8941bf4041144d
tree       18089ba49e1450cf1b76238c9b435883f7003474
parent     7db53391f1b349d1f49844197b34f94806f5e336
[SPARK-3443][MLLIB] update default values of tree:
Adjust the default values of the decision tree, based on the memory requirement discussed in https://github.com/apache/spark/pull/2125:

1. maxMemoryInMB: 128 -> 256
2. maxBins: 100 -> 32
3. maxDepth: 4 -> 5 (in some example code)

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #2322 from mengxr/tree-defaults and squashes the following commits:

cda453a [Xiangrui Meng] fix tests
5900445 [Xiangrui Meng] update comments
8c81831 [Xiangrui Meng] update default values of tree:
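For a sense of why fewer bins and a larger memory budget go together, here is a rough back-of-envelope sketch in Scala. The sizing model (numNodes * numFeatures * numBins * numClasses doubles) and the constants below are illustrative assumptions, not the exact accounting used by the implementation:

// Assumed sizing model: the per-pass histogram aggregate holds roughly
// numNodes * numFeatures * numBins * numClasses doubles (8 bytes each).
val numNodes = 32       // nodes trained in one level-wise pass (assumed)
val numFeatures = 1000  // assumed feature count
val numClasses = 2
def aggregateMB(numBins: Int): Double =
  numNodes.toDouble * numFeatures * numBins * numClasses * 8 / (1024 * 1024)
println(f"maxBins = 100: ${aggregateMB(100)}%.1f MB")  // ~48.8 MB
println(f"maxBins =  32: ${aggregateMB(32)}%.1f MB")   // ~15.6 MB

Under these assumptions, dropping maxBins from 100 to 32 shrinks the aggregate by roughly 3x, while raising maxMemoryInMB from 128 to 256 leaves more headroom for wide or deep trees.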
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala             8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala   6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala       18
3 files changed, 11 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index dd766c12d2..d1309b2b20 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -330,9 +330,9 @@ object DecisionTree extends Serializable with Logging {
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
+ * (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
+ * (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
def trainClassifier(
@@ -374,9 +374,9 @@ object DecisionTree extends Serializable with Logging {
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * (suggested value: 4)
+ * (suggested value: 5)
* @param maxBins maximum number of bins used for splitting features
- * (suggested value: 100)
+ * (suggested value: 32)
* @return DecisionTreeModel that can be used for prediction
*/
def trainRegressor(
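With the updated suggested values, a minimal classifier call looks like the sketch below. It is written against the MLlib API of this era as reflected in the doc comments above; the SparkContext `sc` and the toy two-point dataset are assumptions for illustration:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree

// `sc` is an assumed SparkContext; the dataset is a toy example.
val trainingData = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))

val model = DecisionTree.trainClassifier(
  trainingData,
  2,                // numClassesForClassification
  Map[Int, Int](),  // categoricalFeaturesInfo: empty => all features continuous
  "gini",           // impurity: "gini" (recommended) or "entropy"
  5,                // maxDepth: suggested value after this change (was 4)
  32)               // maxBins: suggested value after this change (was 100)

trainRegressor takes the same maxDepth and maxBins suggestions, with impurity "variance" and no class-count argument.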
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index cfc8192a85..23f74d5360 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -50,7 +50,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
- * 128 MB.
+ * 256 MB.
*/
@Experimental
class Strategy (
@@ -58,10 +58,10 @@ class Strategy (
val impurity: Impurity,
val maxDepth: Int,
val numClassesForClassification: Int = 2,
- val maxBins: Int = 100,
+ val maxBins: Int = 32,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val maxMemoryInMB: Int = 128) extends Serializable {
+ val maxMemoryInMB: Int = 256) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
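With these defaults in place, a Strategy built from only the required arguments now picks up maxBins = 32 and maxMemoryInMB = 256. A minimal sketch:

import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

// Only the required arguments; the optional ones fall back to the new defaults.
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5)
assert(strategy.maxBins == 32)
assert(strategy.maxMemoryInMB == 256)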
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8e556c917b..69482f2acb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext
-
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
def validateClassifier(
@@ -353,8 +352,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
@@ -381,8 +378,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -410,8 +405,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -439,8 +432,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
@@ -464,8 +455,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
- assert(splits(0).length === 99)
- assert(bins(0).length === 100)
// Train a 1-node model
val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
@@ -600,7 +589,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3)
+ numClassesForClassification = 3, maxBins = 100)
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
@@ -626,7 +615,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+ numClassesForClassification = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(metadata.isUnordered(featureIndex = 0))
@@ -652,7 +641,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
- numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ numClassesForClassification = 3, maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
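The suite changes above follow one pattern: tests written against the old default of 100 bins (asserting 99 splits and 100 bins per continuous feature) now pass maxBins = 100 explicitly instead of inheriting the new default of 32. The pinned construction, mirroring the diff:

import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini

// Pin the old default so the existing 100-bin assertions stay valid.
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
  numClassesForClassification = 3, maxBins = 100)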