author     Xiangrui Meng <meng@databricks.com>          2015-11-06 14:51:03 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-11-06 14:51:03 -0800
commit     c447c9d54603890db7399fb80adc9fae40b71f64 (patch)
tree       0f8a339ee0b28a00944bea96879600315ab3ef17
parent     3a652f691b220fada0286f8d0a562c5657973d4d (diff)
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling is different from the design doc. We didn't save default and user-set params separately, and when we load the instance back, all parameters are user-set. This does cause issues, but modifying the default params would cause other issues as well.

TODOs:
* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.
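To make the caveat about user-set params concrete, here is a minimal sketch, assuming an active SparkContext (so the default SQLContext can be resolved) and using Binarizer, whose threshold defaults to 0.0 in this version; the path is made up for illustration:

~~~scala
import org.apache.spark.ml.feature.Binarizer

val b = new Binarizer()            // threshold falls back to its default here
b.save("/tmp/spark-binarizer-default")
val b2 = Binarizer.load("/tmp/spark-binarizer-default")

// After loading, the former default comes back as a user-set value, because
// the writer serializes extractParamMap() (defaults merged in) and the reader
// replays every entry with set().
assert(b2.isSet(b2.threshold))
~~~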
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala            11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala                  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala              220
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java  74
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala        11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala    110
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala            45
7 files changed, 469 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index edad754436..e5c25574d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
final class Binarizer(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol {
+ extends Transformer with Writable with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("binarizer"))
@@ -86,4 +86,11 @@ final class Binarizer(override val uid: String)
}
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object Binarizer extends Readable[Binarizer] {
+
+ override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
}
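The Binarizer wiring above is all a transformer needs to opt in to the new API. A minimal round-trip sketch of the resulting user-facing calls (the path and assertions are illustrative, and an active SparkContext is assumed):

~~~scala
import org.apache.spark.ml.feature.Binarizer

val binarizer = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

// write comes from Writable; overwrite() permits replacing an existing path.
binarizer.write.overwrite().save("/tmp/spark-binarizer")

// read comes from the Readable companion object added above.
val restored = Binarizer.read.load("/tmp/spark-binarizer")
assert(restored.uid == binarizer.uid)
assert(restored.getThreshold == binarizer.getThreshold)
~~~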
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8361406f87..c932570918 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable {
/**
* Sets a parameter in the embedded param map.
*/
- protected final def set[T](param: Param[T], value: T): this.type = {
+ final def set[T](param: Param[T], value: T): this.type = {
set(param -> value)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
new file mode 100644
index 0000000000..ea790e0ddd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
+
+/**
+ * Trait for [[Writer]] and [[Reader]].
+ */
+private[util] sealed trait BaseReadWrite {
+ private var optionSQLContext: Option[SQLContext] = None
+
+ /**
+ * Sets the SQL context to use for saving/loading.
+ */
+ @Since("1.6.0")
+ def context(sqlContext: SQLContext): this.type = {
+ optionSQLContext = Option(sqlContext)
+ this
+ }
+
+ /**
+ * Returns the user-specified SQL context or the default.
+ */
+ protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
+ SQLContext.getOrCreate(SparkContext.getOrCreate())
+ }
+}
+
+/**
+ * Abstract class for utility classes that can save ML instances.
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Writer extends BaseReadWrite {
+
+ protected var shouldOverwrite: Boolean = false
+
+ /**
+ * Saves the ML instances to the input path.
+ */
+ @Since("1.6.0")
+ @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+ def save(path: String): Unit
+
+ /**
+ * Overwrites if the output path already exists.
+ */
+ @Since("1.6.0")
+ def overwrite(): this.type = {
+ shouldOverwrite = true
+ this
+ }
+
+ // override for Java compatibility
+ override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for classes that provide [[Writer]].
+ */
+@Since("1.6.0")
+trait Writable {
+
+ /**
+ * Returns a [[Writer]] instance for this ML instance.
+ */
+ @Since("1.6.0")
+ def write: Writer
+
+ /**
+ * Saves this ML instance to the input path, a shortcut of `write.save(path)`.
+ */
+ @Since("1.6.0")
+ @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+ def save(path: String): Unit = write.save(path)
+}
+
+/**
+ * Abstract class for utility classes that can load ML instances.
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Reader[T] extends BaseReadWrite {
+
+ /**
+ * Loads the ML component from the input path.
+ */
+ @Since("1.6.0")
+ def load(path: String): T
+
+ // override for Java compatibility
+ override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for objects that provide [[Reader]].
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+trait Readable[T] {
+
+ /**
+ * Returns a [[Reader]] instance for this class.
+ */
+ @Since("1.6.0")
+ def read: Reader[T]
+
+ /**
+ * Reads an ML instance from the input path, a shortcut of `read.load(path)`.
+ */
+ @Since("1.6.0")
+ def load(path: String): T = read.load(path)
+}
+
+/**
+ * Default [[Writer]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @param instance object to save
+ */
+private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
+
+ /**
+ * Saves the ML component to the input path.
+ */
+ override def save(path: String): Unit = {
+ val sc = sqlContext.sparkContext
+
+ val hadoopConf = sc.hadoopConfiguration
+ val fs = FileSystem.get(hadoopConf)
+ val p = new Path(path)
+ if (fs.exists(p)) {
+ if (shouldOverwrite) {
+ logInfo(s"Path $path already exists. It will be overwritten.")
+ // TODO: Revert back to the original content if save is not successful.
+ fs.delete(p, true)
+ } else {
+ throw new IOException(
+ s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
+ }
+ }
+
+ val uid = instance.uid
+ val cls = instance.getClass.getName
+ val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+ val jsonParams = params.map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList
+ val metadata = ("class" -> cls) ~
+ ("timestamp" -> System.currentTimeMillis()) ~
+ ("uid" -> uid) ~
+ ("paramMap" -> jsonParams)
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = compact(render(metadata))
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+}
+
+/**
+ * Default [[Reader]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @tparam T ML instance type
+ */
+private[ml] class DefaultParamsReader[T] extends Reader[T] {
+
+ /**
+ * Loads the ML component from the input path.
+ */
+ override def load(path: String): T = {
+ implicit val format = DefaultFormats
+ val sc = sqlContext.sparkContext
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataStr = sc.textFile(metadataPath, 1).first()
+ val metadata = parse(metadataStr)
+ val cls = Utils.classForName((metadata \ "class").extract[String])
+ val uid = (metadata \ "uid").extract[String]
+ val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
+ (metadata \ "paramMap") match {
+ case JObject(pairs) =>
+ pairs.foreach { case (paramName, jsonValue) =>
+ val param = instance.getParam(paramName)
+ val value = param.jsonDecode(compact(render(jsonValue)))
+ instance.set(param, value)
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+ }
+ instance.asInstanceOf[T]
+ }
+}
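For callers that cannot rely on the default context resolved by BaseReadWrite, both Writer and Reader expose context(). A hedged sketch of the read path (sqlContext and the saved path are assumed to already exist; the comment summarizes what the reader does):

~~~scala
import org.apache.spark.ml.feature.Binarizer

// The metadata file holds a single JSON object with "class", "timestamp",
// "uid" and "paramMap". DefaultParamsReader resolves the class by name,
// instantiates it with the saved uid, and replays each paramMap entry
// through Param.jsonDecode followed by set().
val restored = Binarizer.read.context(sqlContext).load("/tmp/spark-binarizer")
~~~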
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
new file mode 100644
index 0000000000..c39538014b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.util.Utils;
+
+public class JavaDefaultReadWriteSuite {
+
+ JavaSparkContext jsc = null;
+ File tempDir = null;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
+ tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
+ }
+
+ @After
+ public void tearDown() {
+ if (jsc != null) {
+ jsc.stop();
+ jsc = null;
+ }
+ Utils.deleteRecursively(tempDir);
+ }
+
+ @Test
+ public void testDefaultReadWrite() throws IOException {
+ String uid = "my_params";
+ MyParams instance = new MyParams(uid);
+ instance.set(instance.intParam(), 2);
+ String outputPath = new File(tempDir, uid).getPath();
+ instance.save(outputPath);
+ try {
+ instance.save(outputPath);
+ Assert.fail(
+ "Write without overwrite enabled should fail if the output directory already exists.");
+ } catch (IOException e) {
+ // expected
+ }
+ SQLContext sqlContext = new SQLContext(jsc);
+ instance.write().context(sqlContext).overwrite().save(outputPath);
+ MyParams newInstance = MyParams.load(outputPath);
+ Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
+ Assert.assertEquals("Params should be preserved.",
+ 2, newInstance.getOrDefault(newInstance.intParam()));
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 2086043983..9dfa1439cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var data: Array[Double] = _
@@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x === y, "The feature value is not correct after binarization.")
}
}
+
+ test("read/write") {
+ val binarizer = new Binarizer()
+ .setInputCol("feature")
+ .setOutputCol("binarized_feature")
+ .setThreshold(0.1)
+ testDefaultReadWrite(binarizer)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
new file mode 100644
index 0000000000..4545b0f281
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.{File, IOException}
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
+
+ /**
+ * Checks "overwrite" option and params.
+ * @param instance ML instance to test saving/loading
+ * @tparam T ML instance type
+ */
+ def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
+ val uid = instance.uid
+ val path = new File(tempDir, uid).getPath
+
+ instance.save(path)
+ intercept[IOException] {
+ instance.save(path)
+ }
+ instance.write.overwrite().save(path)
+ val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]]
+ val newInstance = loader.load(path)
+
+ assert(newInstance.uid === instance.uid)
+ instance.params.foreach { p =>
+ if (instance.isDefined(p)) {
+ (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
+ case (Array(values), Array(newValues)) =>
+ assert(values === newValues, s"Values do not match on param ${p.name}.")
+ case (value, newValue) =>
+ assert(value === newValue, s"Values do not match on param ${p.name}.")
+ }
+ } else {
+ assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
+ }
+ }
+
+ val load = instance.getClass.getMethod("load", classOf[String])
+ val another = load.invoke(instance, path).asInstanceOf[T]
+ assert(another.uid === instance.uid)
+ }
+}
+
+class MyParams(override val uid: String) extends Params with Writable {
+
+ final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
+ final val intParam: IntParam = new IntParam(this, "intParam", "doc")
+ final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
+ final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
+ final val longParam: LongParam = new LongParam(this, "longParam", "doc")
+ final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc")
+ final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc")
+ final val doubleArrayParam: DoubleArrayParam =
+ new DoubleArrayParam(this, "doubleArrayParam", "doc")
+ final val stringArrayParam: StringArrayParam =
+ new StringArrayParam(this, "stringArrayParam", "doc")
+
+ setDefault(intParamWithDefault -> 0)
+ set(intParam -> 1)
+ set(floatParam -> 2.0f)
+ set(doubleParam -> 3.0)
+ set(longParam -> 4L)
+ set(stringParam -> "5")
+ set(intArrayParam -> Array(6, 7))
+ set(doubleArrayParam -> Array(8.0, 9.0))
+ set(stringArrayParam -> Array("10", "11"))
+
+ override def copy(extra: ParamMap): Params = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object MyParams extends Readable[MyParams] {
+
+ override def read: Reader[MyParams] = new DefaultParamsReader[MyParams]
+
+ override def load(path: String): MyParams = read.load(path)
+}
+
+class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
+
+ test("default read/write") {
+ val myParams = new MyParams("my_params")
+ testDefaultReadWrite(myParams)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
new file mode 100644
index 0000000000..2742026a69
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.{BeforeAndAfterAll, Suite}
+
+import org.apache.spark.util.Utils
+
+/**
+ * Trait that creates a temporary directory before all tests and deletes it after all.
+ */
+trait TempDirectory extends BeforeAndAfterAll { self: Suite =>
+
+ private var _tempDir: File = _
+
+ /** Returns the temporary directory as a [[File]] instance. */
+ protected def tempDir: File = _tempDir
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ _tempDir = Utils.createTempDir(this.getClass.getName)
+ }
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(_tempDir)
+ super.afterAll()
+ }
+}