diff options
Diffstat (limited to 'mllib')
3 files changed, 15 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 91c0a56313..de32b7218c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable { */ final def getOrDefault[T](param: Param[T]): T = { shouldOwn(param) - get(param).orElse(getDefault(param)).get + get(param).orElse(getDefault(param)).getOrElse( + throw new NoSuchElementException(s"Failed to find a default value for ${param.name}")) } /** An alias for [[getOrDefault()]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 2c878f8372..dfab82c8b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite { assert(inputCol.toString === s"${uid}__inputCol") + intercept[java.util.NoSuchElementException] { + solver.getOrDefault(solver.handleInvalid) + } + intercept[IllegalArgumentException] { solver.setMaxIter(-1) } @@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite { test("params") { val solver = new TestParams() - import solver.{maxIter, inputCol} + import solver.{handleInvalid, maxIter, inputCol} val params = solver.params - assert(params.length === 2) - assert(params(0).eq(inputCol), "params must be ordered by name") - assert(params(1).eq(maxIter)) + assert(params.length === 3) + assert(params(0).eq(handleInvalid), "params must be ordered by name") + assert(params(1).eq(inputCol), "params must be ordered by name") + assert(params(2).eq(maxIter)) assert(!solver.isSet(maxIter)) assert(solver.isDefined(maxIter)) @@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite { assert(solver.explainParam(maxIter) === "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") assert(solver.explainParams() === - Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) + Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 2759248344..9d23547f28 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.param -import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter} import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. */ -class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter + with HasInputCol { def this() = this(Identifiable.randomUID("testParams")) |