diff options
author | Yong Tang <yong.tang.github@outlook.com> | 2016-04-14 17:23:16 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-14 17:23:16 -0700 |
commit | 01dd1f5c07f5c9ba91389c1556f911b028475cd3 (patch) | |
tree | 62e3e960564366c5226e49d0a7e29cd2d8c58735 /mllib | |
parent | d7e124edfe2578ecdf8e816a4dda3ce430a09172 (diff) | |
download | spark-01dd1f5c07f5c9ba91389c1556f911b028475cd3.tar.gz spark-01dd1f5c07f5c9ba91389c1556f911b028475cd3.tar.bz2 spark-01dd1f5c07f5c9ba91389c1556f911b028475cd3.zip |
[SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes
## What changes were proposed in this pull request?
This fix tries to change RandomForest's supported strategies from using regexes to using parseInt and
parseDouble, for the purpose of robustness and maintainability.
## How was this patch tested?
Existing tests passed.
Author: Yong Tang <yong.tang.github@outlook.com>
Closes #12360 from yongtang/SPARK-14565.
Diffstat (limited to 'mllib')
4 files changed, 25 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index c7cde1563f..5f7c40f607 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.tree.impl import scala.collection.mutable +import scala.util.Try import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.RandomForestParams import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -184,15 +186,22 @@ private[spark] object DecisionTreeMetadata extends Logging { case _ => featureSubsetStrategy } - val isIntRegex = "^([1-9]\\d*)$".r - val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt - case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt - case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt + case _ => + Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match { + case Some(value) => math.min(value, numFeatures) + case None => + Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { + case Some(value) => math.ceil(value * numFeatures).toInt + case _ => throw new IllegalArgumentException(s"Supported values:" + + s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") + } + } } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index b6783911ad..d7559f8950 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import scala.util.Try + import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -346,10 +348,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s", (0.0-1.0], [1-n].", (value: String) => RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) - || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) + || Try(value.toInt).filter(_ > 0).isSuccess + || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) setDefault(featureSubsetStrategy -> "auto") @@ -396,9 +400,6 @@ private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) - - // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) - final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" } private[ml] trait RandomForestClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 26755849ad..ca7fb7f51c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD @@ -76,9 +77,10 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) - || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), + || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess + || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess, s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 6db9ce150d..1719f9fab5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) } |