From 44948a2e9dcad5cd8d1eb749f469e49c5750b5ba Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 2 Sep 2015 21:19:42 -0700 Subject: [SPARK-9723] [ML] params getordefault should throw more useful error Params.getOrDefault should throw a more meaningful exception than what you get from a bad key lookup. Author: Holden Karau Closes #8567 from holdenk/SPARK-9723-params-getordefault-should-throw-more-useful-error. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 3 ++- .../scala/org/apache/spark/ml/param/ParamsSuite.scala | 15 ++++++++++----- .../test/scala/org/apache/spark/ml/param/TestParams.scala | 5 +++-- 3 files changed, 15 insertions(+), 8 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 91c0a56313..de32b7218c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable { */ final def getOrDefault[T](param: Param[T]): T = { shouldOwn(param) - get(param).orElse(getDefault(param)).get + get(param).orElse(getDefault(param)).getOrElse( + throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")) } /** An alias for [[getOrDefault()]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 2c878f8372..dfab82c8b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite { assert(inputCol.toString === s"${uid}__inputCol") + intercept[java.util.NoSuchElementException] { + solver.getOrDefault(solver.handleInvalid) + } + intercept[IllegalArgumentException] { solver.setMaxIter(-1) } @@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{handleInvalid, maxIter, inputCol} val params = solver.params - assert(params.length === 2) - assert(params(0).eq(inputCol), "params must be ordered by name") - assert(params(1).eq(maxIter)) + assert(params.length === 3) + assert(params(0).eq(handleInvalid), "params must be ordered by name") + assert(params(1).eq(inputCol), "params must be ordered by name") + assert(params(2).eq(maxIter)) assert(!solver.isSet(maxIter)) assert(solver.isDefined(maxIter)) @@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite { assert(solver.explainParam(maxIter) === "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === - Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) + Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 2759248344..9d23547f28 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.param -import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter} import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. */ -class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter + with HasInputCol { def this() = this(Identifiable.randomUID("testParams")) -- cgit v1.2.3