author     Joseph K. Bradley <joseph.kurata.bradley@gmail.com>  2014-08-07 00:20:38 -0700
committer  Xiangrui Meng <meng@databricks.com>                  2014-08-07 00:20:38 -0700
commit     8d1dec4fa4798bb48b8947446d306ec9ba6bddb5 (patch)
tree       5d16037d5c0d237ae6fd736c4f95e7b6f17a993a /mllib
parent     75993a65173172da32bbe98751e8c0f55c17a52e (diff)
[mllib] DecisionTree Strategy parameter checks
Added some checks to Strategy to print out meaningful error messages when given
invalid DecisionTree parameters.

CC mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1821 from jkbradley/dt-robustness and squashes the following commits:

4dc449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-robustness
7a61f7b [Joseph K. Bradley] Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters
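For context, a hedged Scala sketch (not part of this commit) of the fail-fast behavior the patch adds. The Strategy constructor call below is written against the era's API from memory, so the named arguments and defaults may differ slightly; the field names themselves (algo, impurity, maxDepth, numClassesForClassification) come from the diff:

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Hypothetical misconfiguration: only one class requested for classification.
    val badStrategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      numClassesForClassification = 1)  // invalid: must be >= 2

    // With this patch the DecisionTree constructor calls strategy.assertValid(),
    // so construction throws IllegalArgumentException with a descriptive message
    // ("... must have numClassesForClassification >= 2 ...") instead of failing
    // obscurely later during training.
    new DecisionTree(badStrategy)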
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala            10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala  31
2 files changed, 38 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c8a8656596..bb50f07be5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -44,6 +44,8 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
+ strategy.assertValid()
+
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -1465,10 +1467,14 @@ object DecisionTree extends Serializable with Logging {
/*
- * Ensure #bins is always greater than the categories. For multiclass classification,
- * #bins should be greater than 2^(maxCategories - 1) - 1.
+ * Ensure numBins is always greater than the categories. For multiclass classification,
+ * numBins should be greater than 2^(maxCategories - 1) - 1.
* It's a limitation of the current implementation but a reasonable trade-off since features
* with large number of categories get favored over continuous features.
+ *
+ * This needs to be checked here instead of in Strategy since numBins can be determined
+ * by the number of training examples.
+ * TODO: Allow this case, where we simply will know nothing about some categories.
*/
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
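For intuition on the 2^(maxCategories - 1) - 1 bound mentioned in the comment above, a small hedged sketch; the helper name minBinsForMulticlassCategorical is invented for illustration and is not part of the patch:

    // Illustration only: in the multiclass case, an unordered categorical feature
    // with maxCategories values yields 2^(maxCategories - 1) - 1 candidate splits,
    // so numBins must exceed that count.
    def minBinsForMulticlassCategorical(maxCategories: Int): Int =
      (1 << (maxCategories - 1)) - 1

    // Example: a 5-category feature requires numBins > 2^4 - 1 = 15.
    assert(minBinsForMulticlassCategorical(5) == 15)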
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 4ee4bcd0bc..f31a503608 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -90,4 +90,33 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
+ private[tree] def assertValid(): Unit = {
+ algo match {
+ case Classification =>
+ require(numClassesForClassification >= 2,
+ s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
+ s" but numClassesForClassification = $numClassesForClassification.")
+ require(Set(Gini, Entropy).contains(impurity),
+ s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
+ s" Valid settings: Gini, Entropy")
+ case Regression =>
+ require(impurity == Variance,
+ s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
+ s" Valid settings: Variance")
+ case _ =>
+ throw new IllegalArgumentException(
+ s"DecisionTree Strategy given invalid algo parameter: $algo." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
+ s" Valid values are integers >= 0.")
+ require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
+ s" Valid values are integers >= 2.")
+ categoricalFeaturesInfo.foreach { case (feature, arity) =>
+ require(arity >= 2,
+ s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
+ s" feature $feature has $arity categories. The number of categories should be >= 2.")
+ }
+ }
+
}
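To round out the picture, a hedged sketch of how the new categoricalFeaturesInfo check surfaces to a caller. The constructor call is again an assumption about the era's API; the error text is taken from the require message in the diff:

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Variance

    // Hypothetical misconfiguration: feature 3 is declared with a single category.
    val badCategorical = new Strategy(
      algo = Algo.Regression,
      impurity = Variance,
      maxDepth = 4,
      categoricalFeaturesInfo = Map(3 -> 1))  // invalid: arity must be >= 2

    // assertValid() is private[tree] and runs inside the DecisionTree constructor,
    // so the problem is reported before any training work starts:
    // "DecisionTree Strategy given invalid categoricalFeaturesInfo setting:
    //  feature 3 has 1 categories. The number of categories should be >= 2."
    new DecisionTree(badCategorical)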