author | noelsmith <mail@noelsmith.com> | 2015-08-27 23:59:30 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-27 23:59:30 -0700 |
commit | 7583681e6b0824d7eed471dc4d8fa0b2addf9ffc (patch) | |
tree | d8cfb483586f9aa024fc328aaf515048426b664e /python/pyspark/ml/tuning.py | |
parent | 89b943438512fcfb239c268b43431397de46cbcf (diff) | |
[SPARK-10188] [PYSPARK] Pyspark CrossValidator with RMSE selects incorrect model
* Added isLargerBetter() method to Pyspark Evaluator to match the Scala version.
* JavaEvaluator delegates isLargerBetter() to underlying Scala object.
* Added check for isLargerBetter() in CrossValidator to determine whether to use argmin or argmax.
* Added test cases for metrics where smaller is better (RMSE) and where larger is better (R-Squared).
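
The evaluator side of the change can be pictured with a minimal, self-contained sketch. It does not use PySpark: `Evaluator`, `RmseEvaluator`, and `DelegatingEvaluator` below are hypothetical stand-ins, and the `True` default for `isLargerBetter()` is an assumption made for illustration. Only the method name `isLargerBetter()` comes from this commit.

```python
class Evaluator:
    """Hypothetical stand-in for an evaluator base class (not the PySpark class)."""

    def evaluate(self, pairs):
        raise NotImplementedError

    def isLargerBetter(self):
        # Assumed default: metrics such as R-squared improve as they grow.
        return True


class RmseEvaluator(Evaluator):
    """Root-mean-square error over (label, prediction) pairs; smaller is better."""

    def evaluate(self, pairs):
        squared_errors = [(label - pred) ** 2 for label, pred in pairs]
        return (sum(squared_errors) / len(squared_errors)) ** 0.5

    def isLargerBetter(self):
        return False


class DelegatingEvaluator(Evaluator):
    """Forwards both calls to a wrapped evaluator, the way a thin wrapper
    would defer to an underlying implementation."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def evaluate(self, pairs):
        return self._wrapped.evaluate(pairs)

    def isLargerBetter(self):
        return self._wrapped.isLargerBetter()


wrapped_rmse = DelegatingEvaluator(RmseEvaluator())
rmse_value = wrapped_rmse.evaluate([(1.0, 1.0), (3.0, 1.0)])
```

Because the wrapper forwards `isLargerBetter()` rather than hard-coding an answer, a caller can ask any evaluator which direction is "better" without knowing which metric it computes.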
(This contribution is my original work and I license the work to the project under Spark's open source license.)
Author: noelsmith <mail@noelsmith.com>
Closes #8399 from noel-smith/pyspark-rmse-xval-fix.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index dcfee6a317..cae778869e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -223,7 +223,11 @@ class CrossValidator(Estimator):
                 # TODO: duplicate evaluator to take extra params from input
                 metric = eva.evaluate(model.transform(validation, epm[j]))
                 metrics[j] += metric
-        bestIndex = np.argmax(metrics)
+
+        if eva.isLargerBetter():
+            bestIndex = np.argmax(metrics)
+        else:
+            bestIndex = np.argmin(metrics)
         bestModel = est.fit(dataset, epm[bestIndex])
         return CrossValidatorModel(bestModel)
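
To see why the old line picked the wrong model for RMSE, here is a small worked sketch in plain Python. The fold scores are made-up numbers, and `select_best_index` is a hypothetical helper that stands in for the `np.argmax` / `np.argmin` branch in the diff; it is not part of PySpark.

```python
def select_best_index(metrics, larger_is_better):
    """Return the index of the best metric, breaking ties toward the first
    occurrence (the same tie-breaking NumPy's argmax/argmin use)."""
    indices = range(len(metrics))
    if larger_is_better:
        return max(indices, key=lambda j: metrics[j])
    return min(indices, key=lambda j: metrics[j])


# Hypothetical RMSE values: one row per fold, one column per parameter setting.
rmse_by_fold = [
    [2.0, 0.5, 1.0],
    [2.5, 0.25, 1.5],
]

# Accumulate per-setting scores across folds, as the loop above the diff does
# with `metrics[j] += metric`.
summed = [sum(column) for column in zip(*rmse_by_fold)]

old_choice = select_best_index(summed, larger_is_better=True)   # always-argmax behaviour
new_choice = select_best_index(summed, larger_is_better=False)  # what RMSE needs
```

With these numbers the summed errors are 4.5, 0.75, and 2.5, so always taking the maximum selects setting 0, the one with the highest error, while taking the minimum selects setting 1.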