aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-08-04 10:12:22 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-04 10:12:22 -0700
commit5a23213c148bfe362514f9c71f5273ebda0a848a (patch)
tree1e2646c72d94b36387581ee8b5d99e14305fe650 /mllib/src/test
parent34a0eb2e89d59b0823efc035ddf2dc93f19540c1 (diff)
downloadspark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.gz
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.bz2
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.zip
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification. Note that the primary author of this PR is holdenk Author: Holden Karau <holden@pigscanfly.ca> Author: Joseph K. Bradley <joseph@databricks.com> Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits: 3952977 [Joseph K. Bradley] fixed pyspark doc test 85febc8 [Joseph K. Bradley] made python unit tests a little more robust 7eb1d86 [Joseph K. Bradley] small cleanups 6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues. 0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests 7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc. 6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests 25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression c02d6c0 [Holden Karau] No default for thresholds 5e43628 [Holden Karau] CR feedback and fixed the renamed test f3fbbd1 [Holden Karau] revert the changes to random forest :( 51f581c [Holden Karau] Add explicit types to public methods, fix long line f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic 398078a [Holden Karau] move the thresholding around a bunch based on the design doc 4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok) 638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test e09919c [Holden Karau] Fix return type, I need more coffee.... 8d92cac [Holden Karau] Use ClassifierParams as the head 3456ed3 [Holden Karau] Add explicit return types even though just test a0f3b0c [Holden Karau] scala style fixes 6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now ffc8dab [Holden Karau] Update the sharedParams 0420290 [Holden Karau] Allow us to override the get methods selectively 978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions 1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there" 1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there efb9084 [Holden Karau] move setThresholds only to where its used 6b34809 [Holden Karau] Add a test with thresholding for the RFCS 74f54c3 [Holden Karau] Fix creation of vote array 1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down. 2f44b18 [Holden Karau] Add a global default of null for thresholds param f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds" 634b06f [Holden Karau] Some progress towards unifying threshold and thresholds 85c9e01 [Holden Karau] Test passes again... little fnur 099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer) 0f46836 [Holden Karau] Start adding a classifiersuite f70eb5e [Holden Karau] Fix test compile issues a7d59c8 [Holden Karau] Move thresholding into Classifier trait 5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test) 1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation 31d6bf2 [Holden Karau] Start threading the threshold info through 0ef228c [Holden Karau] Add hasthresholds
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala28
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala57
4 files changed, 91 insertions, 5 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f75e024a71..fb1de51163 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -87,6 +87,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegression parent = (LogisticRegression) model.parent();
assert(parent.getMaxIter() == 10);
assert(parent.getRegParam() == 1.0);
+ assert(parent.getThresholds()[0] == 0.4);
+ assert(parent.getThresholds()[1] == 0.6);
assert(parent.getThreshold() == 0.6);
assert(model.getThreshold() == 0.6);
@@ -98,7 +100,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
- model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+ double[] thresholds = {1.0, 0.0};
+ model.transform(
+ dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
@@ -108,8 +112,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
assert(foundNonZero);
// Call fit() with new params, and check as many params as we can.
+ double[] thresholds2 = {0.6, 0.4};
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
- lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b7dd447538..da13dcb42d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -91,6 +91,28 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.hasParent)
}
+ test("setThreshold, getThreshold") {
+ val lr = new LogisticRegression
+ // default
+ withClue("LogisticRegression should not have thresholds set by default") {
+ intercept[java.util.NoSuchElementException] {
+ lr.getThresholds
+ }
+ }
+ // Set via thresholds.
+ // Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
+ lr.setThreshold(1.0)
+ assert(lr.getThresholds === Array(0.0, 1.0))
+ lr.setThreshold(0.0)
+ assert(lr.getThresholds === Array(1.0, 0.0))
+ lr.setThreshold(0.5)
+ assert(lr.getThresholds === Array(0.5, 0.5))
+ // Test getThreshold
+ lr.setThresholds(Array(0.3, 0.7))
+ val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
+ assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+ }
+
test("logistic regression doesn't fit intercept when fitIntercept is off") {
val lr = new LogisticRegression
lr.setFitIntercept(false)
@@ -123,14 +145,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
- model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+ model.probabilityCol -> "myProb")
.select("prediction", "myProb")
.collect()
.map { case Row(pred: Double, prob: Vector) => pred }
assert(predNotAllZero.exists(_ !== 0.0))
// Call fit() with new params, and check as many params as we can.
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
+ lr.thresholds -> Array(0.6, 0.4),
lr.probabilityCol -> "theProb")
val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3775292f6d..bd8e819f69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -151,7 +151,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
"copy should handle extra classifier params")
- val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+ val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1)))
ovrModel.models.foreach { case m: LogisticRegressionModel =>
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
new file mode 100644
index 0000000000..8f50cb924e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+final class TestProbabilisticClassificationModel(
+ override val uid: String,
+ override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
+
+ override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)
+
+ override protected def predictRaw(input: Vector): Vector = {
+ input
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction
+ }
+
+ def friendlyPredict(input: Vector): Double = {
+ predict(input)
+ }
+}
+
+
+class ProbabilisticClassifierSuite extends SparkFunSuite {
+
+ test("test thresholding") {
+ val thresholds = Array(0.5, 0.2)
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+ }
+
+ test("test thresholding not required") {
+ val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+ assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+ }
+}