author     Xiangrui Meng <meng@databricks.com>        2015-11-06 14:51:03 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2015-11-06 14:51:03 -0800
commit     c447c9d54603890db7399fb80adc9fae40b71f64 (patch)
tree       0f8a339ee0b28a00944bea96879600315ab3ef17 /mllib/src/test/scala/org
parent     3a652f691b220fada0286f8d0a562c5657973d4d (diff)
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the currently active context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling differs from the design doc: we don't save default and user-set params separately, so when we load an instance back, all parameters are user-set. This does cause issues, but the alternative causes other issues if we modify the default params.

TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.
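For illustration, here is a minimal, hypothetical sketch of the new API applied to Binarizer, one of the transformers exercised in the tests below. The path and threshold value are placeholders, and the companion-object `Binarizer.load` is assumed to be the concrete counterpart of the reflective `read`/`load` lookup in DefaultReadWriteTest.

~~~scala
import org.apache.spark.ml.feature.Binarizer

// Minimal sketch of the default save/load flow (path and values are placeholders).
val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.1)

binarizer.save("/tmp/binarizer")                    // plain save; throws IOException if the path already exists
binarizer.write.overwrite().save("/tmp/binarizer")  // explicit overwrite, as exercised in DefaultReadWriteTest
val restored = Binarizer.load("/tmp/binarizer")     // load back via the companion object
assert(restored.uid == binarizer.uid)
assert(restored.getThreshold == binarizer.getThreshold)
~~~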
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala     |  11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala  | 110
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala         |  45
3 files changed, 165 insertions(+), 1 deletion(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 2086043983..9dfa1439cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var data: Array[Double] = _
@@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x === y, "The feature value is not correct after binarization.")
}
}
+
+ test("read/write") {
+ val binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(0.1)
+ testDefaultReadWrite(binarizer)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
new file mode 100644
index 0000000000..4545b0f281
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.{File, IOException}
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
+
+ /**
+ * Checks "overwrite" option and params.
+ * @param instance ML instance to test saving/loading
+ * @tparam T ML instance type
+ */
+ def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+ val uid = instance.uid
+ val path = new File(tempDir, uid).getPath
+
+ instance.save(path)
+ intercept[IOException] {
+ instance.save(path)
+ }
+ instance.write.overwrite().save(path)
+ val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]]
+ val newInstance = loader.load(path)
+
+ assert(newInstance.uid === instance.uid)
+ instance.params.foreach { p =>
+ if (instance.isDefined(p)) {
+ (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+ case (Array(values), Array(newValues)) =>
+ assert(values === newValues, s"Values do not match on param ${p.name}.")
+ case (value, newValue) =>
+ assert(value === newValue, s"Values do not match on param ${p.name}.")
+ }
+ } else {
+ assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
+ }
+ }
+
+ val load = instance.getClass.getMethod("load", classOf[String])
+ val another = load.invoke(instance, path).asInstanceOf[T]
+ assert(another.uid === instance.uid)
+ }
+}
+
+class MyParams(override val uid: String) extends Params with Writable {
+
+ final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+ final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
+ final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
+ final val longParam: LongParam = new LongParam(this, "longParam", "doc")
+ final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc")
+ final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc")
+ final val doubleArrayParam: DoubleArrayParam =
+ new DoubleArrayParam(this, "doubleArrayParam", "doc")
+ final val stringArrayParam: StringArrayParam =
+ new StringArrayParam(this, "stringArrayParam", "doc")
+
+ setDefault(intParamWithDefault -> 0)
+ set(intParam -> 1)
+ set(floatParam -> 2.0f)
+ set(doubleParam -> 3.0)
+ set(longParam -> 4L)
+ set(stringParam -> "5")
+ set(intArrayParam -> Array(6, 7))
+ set(doubleArrayParam -> Array(8.0, 9.0))
+ set(stringArrayParam -> Array("10", "11"))
+
+ override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object MyParams extends Readable[MyParams] {
+
+ override def read: Reader[MyParams] = new DefaultParamsReader[MyParams]
+
+ override def load(path: String): MyParams = read.load(path)
+}
+
+class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
+
+ test("default read/write") {
+ val myParams = new MyParams("my_params")
+ testDefaultReadWrite(myParams)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
new file mode 100644
index 0000000000..2742026a69
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+import org.apache.spark.util.Utils
+
+/**
+ * Trait that creates a temporary directory before all tests and deletes it after all.
+ */
+trait TempDirectory extends BeforeAndAfterAll { self: Suite =>
+
+ private var _tempDir: File = _
+
+ /** Returns the temporary directory as a [[File]] instance. */
+ protected def tempDir: File = _tempDir
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ _tempDir = Utils.createTempDir(this.getClass.getName)
+ }
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(_tempDir)
+ super.afterAll()
+ }
+}
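As a usage note (not part of this commit), a hypothetical suite mixing in TempDirectory might look like the sketch below; the suite and test names are made up, and `tempDir` is the directory created in `beforeAll` above.

~~~scala
import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.TempDirectory

// Hypothetical example suite: the temporary directory is created once before all
// tests and deleted recursively after the last one by the TempDirectory trait.
class TempDirectoryExampleSuite extends SparkFunSuite with TempDirectory {

  test("files created under tempDir live inside the per-suite directory") {
    val scratch = new File(tempDir, "scratch.txt")
    assert(scratch.getParentFile === tempDir)
  }
}
~~~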