path: root/mllib/src/main
author    Holden Karau <holden@pigscanfly.ca>  2015-08-04 10:12:22 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2015-08-04 10:12:22 -0700
commit    5a23213c148bfe362514f9c71f5273ebda0a848a (patch)
tree      1e2646c72d94b36387581ee8b5d99e14305fe650 /mllib/src/main
parent    34a0eb2e89d59b0823efc035ddf2dc93f19540c1 (diff)
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold/setThreshold for backwards compatibility in binary classification. Note that the primary author of this PR is holdenk.

Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:

3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where its used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add hasthresholds
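
For orientation, here is a minimal usage sketch of the API this patch adds. It is illustrative only: "training" is an assumed DataFrame of (label, features) rows, not something defined in this patch.

    import org.apache.spark.ml.classification.LogisticRegression

    // Sketch: fit a model, then adjust per-class thresholds on the model.
    val lr = new LogisticRegression().setMaxIter(10)
    val model = lr.fit(training) // `training` is an assumed DataFrame

    // New generalized Param: one threshold per class. Prediction picks the
    // class maximizing p / t (p = class probability, t = class threshold).
    model.setThresholds(Array(0.6, 0.4))

    // Binary shim kept for backwards compatibility; equivalent to the above.
    model.setThreshold(0.4)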
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala               |  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala       | 47
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala  | 41
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala        | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala               | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                         |  3
6 files changed, 110 insertions(+), 20 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 581d8fa774..45df557a89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,14 +18,13 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fc9199fb4..c937b9602b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -41,7 +41,39 @@ import org.apache.spark.storage.StorageLevel
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
- with HasThreshold with HasStandardization
+ with HasStandardization {
+
+ /**
+ * Version of setThresholds() for binary classification, available for backwards
+ * compatibility.
+ *
+ * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
+ *
+ * Default is effectively 0.5.
+ * @group setParam
+ */
+ def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
+
+ /**
+ * Version of [[getThresholds()]] for binary classification, available for backwards
+ * compatibility.
+ *
+ * Param thresholds must have length 2 (or not be specified).
+ * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+ * @group getParam
+ */
+ def getThreshold: Double = {
+ if (isDefined(thresholds)) {
+ val thresholdValues = $(thresholds)
+ assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
+ " binary classification, but thresholds has length != 2." +
+ s" thresholds: ${thresholdValues.mkString(",")}")
+ 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
+ } else {
+ 0.5
+ }
+ }
+}
/**
* :: Experimental ::
@@ -110,9 +142,9 @@ class LogisticRegression(override val uid: String)
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
- /** @group setParam */
- def setThreshold(value: Double): this.type = set(threshold, value)
- setDefault(threshold -> 0.5)
+ override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+ override def getThreshold: Double = super.getThreshold
override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
@@ -270,8 +302,9 @@ class LogisticRegressionModel private[ml] (
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
with LogisticRegressionParams {
- /** @group setParam */
- def setThreshold(value: Double): this.type = set(threshold, value)
+ override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+ override def getThreshold: Double = super.getThreshold
/** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
@@ -288,7 +321,7 @@ class LogisticRegressionModel private[ml] (
/**
* Predict label for the given feature vector.
- * The behavior of this can be adjusted using [[threshold]].
+ * The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
if (score(features) > getThreshold) 1 else 0
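
The binary shim above round-trips exactly: setThreshold(p) stores Array(1-p, p), and getThreshold computes 1 / (1 + t0/t1) = t1 / (t0 + t1), which recovers p. A standalone sketch of that identity (plain Scala, no Spark dependencies; names are illustrative):

    // Round trip of the binary compatibility shim.
    def toThresholds(p: Double): Array[Double] = Array(1.0 - p, p)        // setThreshold
    def toThreshold(t: Array[Double]): Double = 1.0 / (1.0 + t(0) / t(1)) // getThreshold

    val p = 0.3
    assert(math.abs(toThreshold(toThresholds(p)) - p) < 1e-12)
    // Also well-defined for unnormalized pairs: 6 / (2 + 6) = 0.75
    assert(math.abs(toThreshold(Array(2.0, 6.0)) - 0.75) < 1e-12)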
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index f9c9c2371f..1e50a895a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,17 +20,16 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* (private[classification]) Params for probabilistic classification.
*/
private[classification] trait ProbabilisticClassifierParams
- extends ClassifierParams with HasProbabilityCol {
-
+ extends ClassifierParams with HasProbabilityCol with HasThresholds {
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
@@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[
/** @group setParam */
def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+
+ /** @group setParam */
+ def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
}
@@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
/** @group setParam */
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+ /** @group setParam */
+ def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+
/**
* Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
* parameters:
@@ -92,6 +97,11 @@ private[spark] abstract class ProbabilisticClassificationModel[
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ if (isDefined(thresholds)) {
+ require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+ ".transform() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+ }
// Output selected columns only.
// This is a bit complicated since it tries to avoid repeated computation.
@@ -155,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[
raw2probabilityInPlace(probs)
}
+ override protected def raw2prediction(rawPrediction: Vector): Double = {
+ if (!isDefined(thresholds)) {
+ rawPrediction.argmax
+ } else {
+ probability2prediction(raw2probability(rawPrediction))
+ }
+ }
+
/**
* Predict the probability of each class given the features.
* These predictions are also called class conditional probabilities.
@@ -170,10 +188,21 @@ private[spark] abstract class ProbabilisticClassificationModel[
/**
* Given a vector of class conditional probabilities, select the predicted label.
- * This may be overridden to support thresholds which favor particular labels.
+ * This supports thresholds which favor particular labels.
* @return predicted label
*/
- protected def probability2prediction(probability: Vector): Double = probability.argmax
+ protected def probability2prediction(probability: Vector): Double = {
+ if (!isDefined(thresholds)) {
+ probability.argmax
+ } else {
+ val thresholds: Array[Double] = getThresholds
+ val scaledProbability: Array[Double] =
+ probability.toArray.zip(thresholds).map { case (p, t) =>
+ if (t == 0.0) Double.PositiveInfinity else p / t
+ }
+ Vectors.dense(scaledProbability).argmax
+ }
+ }
}
private[ml] object ProbabilisticClassificationModel {
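
To make the thresholding rule concrete: when thresholds are unset, raw2prediction keeps the cheap rawPrediction.argmax path; when set, prediction routes through probabilities and each probability is scaled by its class threshold. A standalone re-implementation of that rule (plain Scala arrays rather than Spark's Vector, for illustration only):

    // The p/t scaling from probability2prediction above. A zero threshold maps
    // to +Infinity, so a zero-threshold class always beats any finite score.
    def predict(probability: Array[Double], thresholds: Array[Double]): Int = {
      val scaled = probability.zip(thresholds).map { case (p, t) =>
        if (t == 0.0) Double.PositiveInfinity else p / t
      }
      scaled.indexOf(scaled.max)
    }

    // Thresholds can flip the argmax decision:
    assert(predict(Array(0.4, 0.6), Array(0.5, 1.0)) == 0) // 0.4/0.5 = 0.8 beats 0.6
    assert(predict(Array(0.4, 0.6), Array(1.0, 1.0)) == 1) // uniform = plain argmax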
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index f7ae1de522..a97c8059b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -46,7 +46,13 @@ private[shared] object SharedParamsCodeGen {
Some("\"probability\"")),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]",
- isValid = "ParamValidators.inRange(0, 1)"),
+ isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
+ ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
+ " to adjust the probability of predicting each class." +
+ " Array must have length equal to the number of classes, with values >= 0." +
+ " The class with largest value p/t is predicted, where p is the original probability" +
+ " of that class and t is the class' threshold.",
+ isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
@@ -74,7 +80,8 @@ private[shared] object SharedParamsCodeGen {
name: String,
doc: String,
defaultValueStr: Option[String] = None,
- isValid: String = "") {
+ isValid: String = "",
+ finalMethods: Boolean = true) {
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -88,6 +95,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
+ case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
@@ -131,6 +139,11 @@ private[shared] object SharedParamsCodeGen {
} else {
""
}
+ val methodStr = if (param.finalMethods) {
+ "final def"
+ } else {
+ "def"
+ }
s"""
|/**
@@ -145,7 +158,7 @@ private[shared] object SharedParamsCodeGen {
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
|$setDefault
| /** @group getParam */
- | final def get$Name: $T = $$($name)
+ | $methodStr get$Name: $T = $$($name)
|}
|""".stripMargin
}
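
The new finalMethods flag lets the generator emit an overridable (non-final) getter; the regenerated HasThreshold in sharedParams.scala below shows the result. An abbreviated sketch of the two getter shapes (doc comments elided; an illustration, not verbatim generator output):

    // finalMethods = true (the default): getter is final.
    trait HasInputCol extends Params {
      final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
      final def getInputCol: String = $(inputCol)
    }

    // finalMethods = false (used for threshold): getter can be overridden,
    // as LogisticRegressionParams does when it redefines getThreshold in
    // terms of thresholds.
    trait HasThreshold extends Params {
      final val threshold: DoubleParam = new DoubleParam(this, "threshold",
        "threshold in binary classification prediction, in range [0, 1]",
        ParamValidators.inRange(0, 1))
      def getThreshold: Double = $(threshold)
    }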
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 65e48e4ee5..f332630c32 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -150,7 +150,22 @@ private[ml] trait HasThreshold extends Params {
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
/** @group getParam */
- final def getThreshold: Double = $(threshold)
+ def getThreshold: Double = $(threshold)
+}
+
+/**
+ * Trait for shared param thresholds.
+ */
+private[ml] trait HasThresholds extends Params {
+
+ /**
+ * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+ * @group param
+ */
+ final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
+
+ /** @group getParam */
+ final def getThresholds: Array[Double] = $(thresholds)
}
/**
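
Note that the thresholds validator only requires non-negative entries; the array need not sum to 1, since prediction depends only on the ratios p/t. The predicate, extracted here for a quick standalone check:

    // The isValid predicate attached to the thresholds Param.
    val isValidThresholds: Array[Double] => Boolean = _.forall(_ >= 0)

    assert(isValidThresholds(Array(0.5, 1.5)))    // valid: need not sum to 1
    assert(!isValidThresholds(Array(0.5, -0.1)))  // invalid: negative entry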
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index a0c5238d96..e817090f8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,9 +17,10 @@
package org.apache.spark.ml.tree
+import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}