aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
authornoelsmith <mail@noelsmith.com>2015-08-27 23:59:30 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-27 23:59:30 -0700
commit7583681e6b0824d7eed471dc4d8fa0b2addf9ffc (patch)
treed8cfb483586f9aa024fc328aaf515048426b664e /python/pyspark/ml/tuning.py
parent89b943438512fcfb239c268b43431397de46cbcf (diff)
downloadspark-7583681e6b0824d7eed471dc4d8fa0b2addf9ffc.tar.gz
spark-7583681e6b0824d7eed471dc4d8fa0b2addf9ffc.tar.bz2
spark-7583681e6b0824d7eed471dc4d8fa0b2addf9ffc.zip
[SPARK-10188] [PYSPARK] Pyspark CrossValidator with RMSE selects incorrect model
* Added isLargerBetter() method to Pyspark Evaluator to match the Scala version. * JavaEvaluator delegates isLargerBetter() to underlying Scala object. * Added check for isLargerBetter() in CrossValidator to determine whether to use argmin or argmax. * Added test cases for where smaller is better (RMSE) and larger is better (R-Squared). (This contribution is my original work and that I license the work to the project under Sparks' open source license) Author: noelsmith <mail@noelsmith.com> Closes #8399 from noel-smith/pyspark-rmse-xval-fix.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index dcfee6a317..cae778869e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -223,7 +223,11 @@ class CrossValidator(Estimator):
# TODO: duplicate evaluator to take extra params from input
metric = eva.evaluate(model.transform(validation, epm[j]))
metrics[j] += metric
- bestIndex = np.argmax(metrics)
+
+ if eva.isLargerBetter():
+ bestIndex = np.argmax(metrics)
+ else:
+ bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
return CrossValidatorModel(bestModel)