aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-08-04 10:12:22 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-04 10:12:22 -0700
commit5a23213c148bfe362514f9c71f5273ebda0a848a (patch)
tree1e2646c72d94b36387581ee8b5d99e14305fe650 /examples
parent34a0eb2e89d59b0823efc035ddf2dc93f19540c1 (diff)
downloadspark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.gz
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.tar.bz2
spark-5a23213c148bfe362514f9c71f5273ebda0a848a.zip
[SPARK-8069] [ML] Add multiclass thresholds for ProbabilisticClassifier
This PR replaces the old "threshold" with a generalized "thresholds" Param. We keep getThreshold,setThreshold for backwards compatibility for binary classification. Note that the primary author of this PR is holdenk Author: Holden Karau <holden@pigscanfly.ca> Author: Joseph K. Bradley <joseph@databricks.com> Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits: 3952977 [Joseph K. Bradley] fixed pyspark doc test 85febc8 [Joseph K. Bradley] made python unit tests a little more robust 7eb1d86 [Joseph K. Bradley] small cleanups 6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues. 0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests 7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared parems codegen, etc. 6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests 25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression c02d6c0 [Holden Karau] No default for thresholds 5e43628 [Holden Karau] CR feedback and fixed the renamed test f3fbbd1 [Holden Karau] revert the changes to random forest :( 51f581c [Holden Karau] Add explicit types to public methods, fix long line f7032eb [Holden Karau] Fix a java test bug, remove some unecessary changes adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic 398078a [Holden Karau] move the thresholding around a bunch based on the design doc 4893bdc [Holden Karau] Use numtrees of 3 since previous result was tied (one tree for each) and the switch from different max methods picked a different element (since they were equal I think this is ok) 638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test e09919c [Holden Karau] Fix return type, I need more coffee.... 8d92cac [Holden Karau] Use ClassifierParams as the head 3456ed3 [Holden Karau] Add explicit return types even though just test a0f3b0c [Holden Karau] scala style fixes 6f14314 [Holden Karau] Since hasthreshold/hasthresholds is in root classifier now ffc8dab [Holden Karau] Update the sharedParams 0420290 [Holden Karau] Allow us to override the get methods selectively 978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions 1433e52 [Holden Karau] Revert "try and hide threshold but chainges the API so no dice there" 1f09a2e [Holden Karau] try and hide threshold but chainges the API so no dice there efb9084 [Holden Karau] move setThresholds only to where its used 6b34809 [Holden Karau] Add a test with thresholding for the RFCS 74f54c3 [Holden Karau] Fix creation of vote array 1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so lets push it down. 2f44b18 [Holden Karau] Add a global default of null for thresholds param f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds" 634b06f [Holden Karau] Some progress towards unifying threshold and thresholds 85c9e01 [Holden Karau] Test passes again... little fnur 099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer) 0f46836 [Holden Karau] Start adding a classifiersuite f70eb5e [Holden Karau] Fix test compile issues a7d59c8 [Holden Karau] Move thresholding into Classifier trait 5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test) 1fed644 [Holden Karau] Use thresholds to scale scores in random forest classifcation 31d6bf2 [Holden Karau] Start threading the threshold info through 0ef228c [Holden Karau] Add hasthresholds
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java3
-rw-r--r--examples/src/main/python/ml/simple_params_example.py2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala2
3 files changed, 4 insertions, 3 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index dac649d1d5..94beeced3d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -77,7 +77,8 @@ public class JavaSimpleParamsExample {
ParamMap paramMap = new ParamMap();
paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
- paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
+ double thresholds[] = {0.45, 0.55};
+ paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params.
// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index a9f29dab2d..2d6d115d54 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
# We may alternatively specify parameters using a parameter map.
# paramMap overrides all lr parameters set earlier.
- paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+ paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}
# Now learn a new model using the new parameters.
model2 = lr.fit(training, paramMap)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 58d7b67674..f4d1fe5785 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -70,7 +70,7 @@ object SimpleParamsExample {
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
- paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
+ paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params.
// One can also combine ParamMaps.
val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name