aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-13 21:18:05 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-13 21:18:05 -0700
commit971b95b0c9002bd541bcbe0da54a9967ba22588f (patch)
treeb2a79cf00c1d2290e7e4024df27c0ee9b203c09a /mllib/src/test
parent0ba3fdd5992cf09bd38303ebff34d2ed19e5e09b (diff)
downloadspark-971b95b0c9002bd541bcbe0da54a9967ba22588f.tar.gz
spark-971b95b0c9002bd541bcbe0da54a9967ba22588f.tar.bz2
spark-971b95b0c9002bd541bcbe0da54a9967ba22588f.zip
[SPARK-5957][ML] better handling of parameters
The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley 1. Use codegen for shared params. 1. Move shared params to package `ml.param.shared`. 1. Set default values in `Params` instead of in `Param`. 1. Add a few methods to `Params` and `ParamMap`. 1. Move schema handling to `SchemaUtils` from `Params`. - [x] check visibility of the methods added Author: Xiangrui Meng <meng@databricks.com> Closes #5431 from mengxr/SPARK-5957 and squashes the following commits: d19236d [Xiangrui Meng] fix test 26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected 38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain() 409e2d5 [Xiangrui Meng] address comments 2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 eec2264 [Xiangrui Meng] make get* public in Params 4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 4fee9e7 [Xiangrui Meng] re-gen shared params 2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen e938f81 [Xiangrui Meng] update code to set default parameter values 28ed322 [Xiangrui Meng] merge master 55be1f3 [Xiangrui Meng] merge master d63b5cc [Xiangrui Meng] fix examples 29b004c [Xiangrui Meng] update ParamsSuite 94fd98e [Xiangrui Meng] fix explain params 48d0e84 [Xiangrui Meng] add remove and update explainParams 4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params 0d9594e [Xiangrui Meng] add getOrElse to ParamMap eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues 0d3fc5b [Xiangrui Meng] setDefault after param a9dbf59 [Xiangrui Meng] minor updates d9302b8 [Xiangrui Meng] generate default values 1c72579 [Xiangrui Meng] pass test compile abb7a3b [Xiangrui Meng] update default values handling dcab97a [Xiangrui Meng] add codegen for shared params
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala47
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala12
2 files changed, 45 insertions, 14 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 1ce2987612..88ea679eea 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,19 +21,25 @@ import org.scalatest.FunSuite
class ParamsSuite extends FunSuite {
- val solver = new TestParams()
- import solver.{inputCol, maxIter}
-
test("param") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
assert(maxIter.name === "maxIter")
assert(maxIter.doc === "max number of iterations")
- assert(maxIter.defaultValue.get === 100)
assert(maxIter.parent.eq(solver))
- assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
- assert(inputCol.defaultValue === None)
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+
+ solver.setMaxIter(5)
+ assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+
+ assert(inputCol.toString === "inputCol: input column name (undefined)")
}
test("param pair") {
+ val solver = new TestParams()
+ import solver.maxIter
+
val pair0 = maxIter -> 5
val pair1 = maxIter.w(5)
val pair2 = ParamPair(maxIter, 5)
@@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite {
}
test("param map") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
val map0 = ParamMap.empty
assert(!map0.contains(maxIter))
- assert(map0(maxIter) === maxIter.defaultValue.get)
map0.put(maxIter, 10)
assert(map0.contains(maxIter))
assert(map0(maxIter) === 10)
@@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite {
}
test("params") {
+ val solver = new TestParams()
+ import solver.{maxIter, inputCol}
+
val params = solver.params
- assert(params.size === 2)
+ assert(params.length === 2)
assert(params(0).eq(inputCol), "params must be ordered by name")
assert(params(1).eq(maxIter))
+
+ assert(!solver.isSet(maxIter))
+ assert(solver.isDefined(maxIter))
+ assert(solver.getMaxIter === 10)
+ solver.setMaxIter(100)
+ assert(solver.isSet(maxIter))
+ assert(solver.getMaxIter === 100)
+ assert(!solver.isSet(inputCol))
+ assert(!solver.isDefined(inputCol))
+ intercept[NoSuchElementException](solver.getInputCol)
+
assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+
assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter))
- intercept[NoSuchMethodException] {
+ intercept[NoSuchElementException] {
solver.getParam("abc")
}
- assert(!solver.isSet(inputCol))
+
intercept[IllegalArgumentException] {
solver.validate()
}
solver.validate(ParamMap(inputCol -> "input"))
solver.setInputCol("input")
assert(solver.isSet(inputCol))
+ assert(solver.isDefined(inputCol))
assert(solver.getInputCol === "input")
solver.validate()
intercept[IllegalArgumentException] {
@@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite {
intercept[IllegalArgumentException] {
solver.validate()
}
+
+ solver.clearMaxIter()
+ assert(!solver.isSet(maxIter))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index ce52f2f230..8f9ab687c0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -20,17 +20,21 @@ package org.apache.spark.ml.param
/** A subclass of Params for testing. */
class TestParams extends Params {
- val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+ val maxIter = new IntParam(this, "maxIter", "max number of iterations")
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
- def getMaxIter: Int = get(maxIter)
+ def getMaxIter: Int = getOrDefault(maxIter)
val inputCol = new Param[String](this, "inputCol", "input column name")
def setInputCol(value: String): this.type = { set(inputCol, value); this }
- def getInputCol: String = get(inputCol)
+ def getInputCol: String = getOrDefault(inputCol)
+
+ setDefault(maxIter -> 10)
override def validate(paramMap: ParamMap): Unit = {
- val m = this.paramMap ++ paramMap
+ val m = extractParamMap(paramMap)
require(m(maxIter) >= 0)
require(m.contains(inputCol))
}
+
+ def clearMaxIter(): this.type = clear(maxIter)
}