aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-09-02 21:19:42 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-02 21:19:42 -0700
commit44948a2e9dcad5cd8d1eb749f469e49c5750b5ba (patch)
tree71ec76373b3f4c45f6c49b146acd19ae8d850554 /mllib
parent03f3e91ff21707d8a1c7057a00f1b1cd8b743e3f (diff)
downloadspark-44948a2e9dcad5cd8d1eb749f469e49c5750b5ba.tar.gz
spark-44948a2e9dcad5cd8d1eb749f469e49c5750b5ba.tar.bz2
spark-44948a2e9dcad5cd8d1eb749f469e49c5750b5ba.zip
[SPARK-9723] [ML] params getordefault should throw more useful error
Params.getOrDefault should throw a more meaningful exception than what you get from a bad key lookup. Author: Holden Karau <holden@pigscanfly.ca> Closes #8567 from holdenk/SPARK-9723-params-getordefault-should-throw-more-useful-error.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala5
3 files changed, 15 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 91c0a56313..de32b7218c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable {
*/
final def getOrDefault[T](param: Param[T]): T = {
shouldOwn(param)
- get(param).orElse(getDefault(param)).get
+ get(param).orElse(getDefault(param)).getOrElse(
+ throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
}
/** An alias for [[getOrDefault()]]. */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 2c878f8372..dfab82c8b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite {
assert(inputCol.toString === s"${uid}__inputCol")
+ intercept[java.util.NoSuchElementException] {
+ solver.getOrDefault(solver.handleInvalid)
+ }
+
intercept[IllegalArgumentException] {
solver.setMaxIter(-1)
}
@@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite {
test("params") {
val solver = new TestParams()
- import solver.{maxIter, inputCol}
+ import solver.{handleInvalid, maxIter, inputCol}
val params = solver.params
- assert(params.length === 2)
- assert(params(0).eq(inputCol), "params must be ordered by name")
- assert(params(1).eq(maxIter))
+ assert(params.length === 3)
+ assert(params(0).eq(handleInvalid), "params must be ordered by name")
+ assert(params(1).eq(inputCol), "params must be ordered by name")
+ assert(params(2).eq(maxIter))
assert(!solver.isSet(maxIter))
assert(solver.isDefined(maxIter))
@@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite {
assert(solver.explainParam(maxIter) ===
"maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() ===
- Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
+ Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))
assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 2759248344..9d23547f28 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -17,11 +17,12 @@
package org.apache.spark.ml.param
-import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter}
import org.apache.spark.ml.util.Identifiable
/** A subclass of Params for testing. */
-class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol {
+class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter
+ with HasInputCol {
def this() = this(Identifiable.randomUID("testParams"))