author     Joseph K. Bradley <joseph@databricks.com>    2015-08-14 10:48:02 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-08-14 10:48:02 -0700
commit     7ecf0c46990c39df8aeddbd64ca33d01824bcc0a (patch)
tree       b239fcb70b03a084323c76dd34c299107cd43181 /mllib
parent     a0e1abbd010b9e73d472ce12ff1d987678005d32 (diff)
[SPARK-9956] [ML] Make trees work with one-category features
This modifies DecisionTreeMetadata construction to treat 1-category categorical features as continuous, so that trees no longer fail with such features. This matters for the Pipelines API, where VectorIndexer can automatically mark certain features as categorical; a feature with a single observed value ends up with exactly one category.

As stated in the JIRA, this is a temporary fix which we can improve upon later by automatically filtering out those features. That will take longer, though, since it will require careful indexing.

Targeted for 1.5 and master.

CC: manishamde mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8187 from jkbradley/tree-1cat.
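For context, here is a minimal sketch (not part of this commit) of the pipeline scenario the fix targets: a constant feature column that VectorIndexer flags as a 1-category categorical feature before a decision tree is fit. The data values are made up, and a Spark 1.5-style SQLContext named sqlContext is assumed to be in scope.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (0.0, Vectors.dense(0.0, 2.0, 3.0)),  // feature 0 is constant across all rows,
  (1.0, Vectors.dense(0.0, 3.0, 1.0)),  // so VectorIndexer will give it a single category
  (0.0, Vectors.dense(0.0, 2.0, 6.0))
)).toDF("label", "features")

// Index labels so the classifier can read numClasses from column metadata.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(df)

// Any feature with <= 4 distinct values is flagged as categorical; the constant
// column therefore becomes a categorical feature with exactly 1 category.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)

val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Before this patch, training could fail on the 1-category feature; with the patch,
// DecisionTreeMetadata treats that feature as continuous and the fit succeeds.
val model = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt)).fit(df)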
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala            | 27
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala   | 13
2 files changed, 30 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 9fe264656e..21ee49c457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -144,21 +144,28 @@ private[spark] object DecisionTreeMetadata extends Logging {
         val maxCategoriesForUnorderedFeature =
           ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
         strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-          // Decide if some categorical features should be treated as unordered features,
-          //  which require 2 * ((1 << numCategories - 1) - 1) bins.
-          // We do this check with log values to prevent overflows in case numCategories is large.
-          // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
-          if (numCategories <= maxCategoriesForUnorderedFeature) {
-            unorderedFeatures.add(featureIndex)
-            numBins(featureIndex) = numUnorderedBins(numCategories)
-          } else {
-            numBins(featureIndex) = numCategories
+          // Hack: If a categorical feature has only 1 category, we treat it as continuous.
+          // TODO(SPARK-9957): Handle this properly by filtering out those features.
+          if (numCategories > 1) {
+            // Decide if some categorical features should be treated as unordered features,
+            //  which require 2 * ((1 << numCategories - 1) - 1) bins.
+            // We do this check with log values to prevent overflows in case numCategories is large.
+            // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
+            if (numCategories <= maxCategoriesForUnorderedFeature) {
+              unorderedFeatures.add(featureIndex)
+              numBins(featureIndex) = numUnorderedBins(numCategories)
+            } else {
+              numBins(featureIndex) = numCategories
+            }
           }
         }
       } else {
         // Binary classification or regression
         strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
-          numBins(featureIndex) = numCategories
+          // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
+          if (numCategories > 1) {
+            numBins(featureIndex) = numCategories
+          }
         }
       }
     }
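As a side note, the arithmetic behind the "2 * ((1 << numCategories - 1) - 1) bins" comments above can be checked directly. The sketch below (not from the patch, using 32 as an example value for maxPossibleBins) mirrors DecisionTreeMetadata.numUnorderedBins and verifies that the log-based bound selects exactly the arities whose unordered-bin count fits within maxPossibleBins.

// Each of the 2^(M-1) - 1 nontrivial ways to split M categories into two non-empty
// subsets defines one split, and every split contributes 2 bins.
def numUnorderedBins(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)

val maxPossibleBins = 32  // example value; the real code derives it from maxBins and numExamples
val maxCategoriesForUnorderedFeature =
  ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt  // = 5 for maxPossibleBins = 32

// 2 * (2^(M-1) - 1) <= 32 holds exactly when M <= 5, matching the log-based bound.
(1 to 10).foreach { m =>
  assert((numUnorderedBins(m) <= maxPossibleBins) == (m <= maxCategoriesForUnorderedFeature))
}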
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 4b7c5d3f23..f680d8d3c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -261,6 +261,19 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     }
   }
 
+  test("training with 1-category categorical feature") {
+    val data = sc.parallelize(Seq(
+      LabeledPoint(0, Vectors.dense(0, 2, 3)),
+      LabeledPoint(1, Vectors.dense(0, 3, 1)),
+      LabeledPoint(0, Vectors.dense(0, 2, 2)),
+      LabeledPoint(1, Vectors.dense(0, 3, 9)),
+      LabeledPoint(0, Vectors.dense(0, 2, 6))
+    ))
+    val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+    val dt = new DecisionTreeClassifier().setMaxDepth(3)
+    val model = dt.fit(df)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
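The new test exercises the metadata through the Pipelines path via TreeTests.setMetadata, where Map(0 -> 1) marks feature 0 as categorical with a single category and 2 is the number of classes. Roughly the same scenario can be reproduced through the RDD-based mllib API, where categoricalFeaturesInfo feeds directly into DecisionTreeMetadata. The following is an illustrative sketch (not part of the patch) with made-up data, assuming a SparkContext named sc.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 2.0, 3.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 3.0, 1.0)),
  LabeledPoint(0.0, Vectors.dense(0.0, 2.0, 6.0))
))

// Feature 0 is declared categorical with a single category (Map(0 -> 1)); with this
// patch DecisionTreeMetadata treats it as continuous instead of giving it a single bin.
val model = DecisionTree.trainClassifier(
  input = data,
  numClasses = 2,
  categoricalFeaturesInfo = Map(0 -> 1),
  impurity = "gini",
  maxDepth = 3,
  maxBins = 32)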