author     Xiangrui Meng <meng@databricks.com>    2015-11-06 14:51:03 -0800
committer  Joseph K. Bradley <joseph@databricks.com>    2015-11-06 14:51:03 -0800
commit     c447c9d54603890db7399fb80adc9fae40b71f64 (patch)
tree       0f8a339ee0b28a00944bea96879600315ab3ef17 /mllib/src/main/scala/org
parent     3a652f691b220fada0286f8d0a562c5657973d4d (diff)
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap

The save/load interface is similar to DataFrames. We use the currently active SQL context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling is different from the design doc: we didn't save default and user-set params separately, and when we load an instance back, all parameters are treated as user-set. This does cause issues, but it would also cause other issues if we modified the default params. TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala   11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala          2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala      220
3 files changed, 230 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index edad754436..e5c25574d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
final class Binarizer(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol {
+ extends Transformer with Writable with HasInputCol with HasOutputCol {
def this() = this(Identifiable.randomUID("binarizer"))
@@ -86,4 +86,11 @@ final class Binarizer(override val uid: String)
}
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object Binarizer extends Readable[Binarizer] {
+
+ override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
}
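With this change, Binarizer mixes in Writable (giving every instance `write` and `save`) and its companion object mixes in Readable (giving `read` and `load`), both backed by the default params writer/reader added in ReadWrite.scala below. A minimal round-trip sketch, assuming a running SparkContext; the path and param values are illustrative:

~~~scala
import org.apache.spark.ml.feature.Binarizer

val original = new Binarizer()
  .setInputCol("feature")
  .setOutputCol("binarized_feature")
  .setThreshold(0.5)

// Save the params-only transformer; overwrite() avoids the IOException thrown when the path exists.
original.write.overwrite().save("/tmp/binarizer-demo")

// Load it back through the companion object's Readable shortcut.
val restored = Binarizer.load("/tmp/binarizer-demo")
assert(restored.uid == original.uid)
assert(restored.getThreshold == original.getThreshold)
~~~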
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8361406f87..c932570918 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable {
/**
* Sets a parameter in the embedded param map.
*/
- protected final def set[T](param: Param[T], value: T): this.type = {
+ final def set[T](param: Param[T], value: T): this.type = {
set(param -> value)
}
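Dropping `protected` here is what allows generic persistence code that lives outside a Params subclass, such as DefaultParamsReader in ReadWrite.scala below, to set param values on an arbitrary instance. A rough sketch of that pattern, with a hypothetical helper name:

~~~scala
import org.apache.spark.ml.param.Params

// Look a param up by name on any Params instance and assign a JSON-decoded value to it
// from outside the class hierarchy; this mirrors what DefaultParamsReader.load does and
// only compiles because `set` is now public.
def setByName(instance: Params, paramName: String, jsonValue: String): Unit = {
  val param = instance.getParam(paramName)
  instance.set(param, param.jsonDecode(jsonValue))
}
~~~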
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
new file mode 100644
index 0000000000..ea790e0ddd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
+
+/**
+ * Trait for [[Writer]] and [[Reader]].
+ */
+private[util] sealed trait BaseReadWrite {
+  private var optionSQLContext: Option[SQLContext] = None
+
+  /**
+   * Sets the SQL context to use for saving/loading.
+   */
+  @Since("1.6.0")
+  def context(sqlContext: SQLContext): this.type = {
+    optionSQLContext = Option(sqlContext)
+    this
+  }
+
+  /**
+   * Returns the user-specified SQL context or the default.
+   */
+  protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
+    SQLContext.getOrCreate(SparkContext.getOrCreate())
+  }
+}
+
+/**
+ * Abstract class for utility classes that can save ML instances.
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Writer extends BaseReadWrite {
+
+  protected var shouldOverwrite: Boolean = false
+
+  /**
+   * Saves the ML instance to the input path.
+   */
+  @Since("1.6.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  def save(path: String): Unit
+
+  /**
+   * Overwrites if the output path already exists.
+   */
+  @Since("1.6.0")
+  def overwrite(): this.type = {
+    shouldOverwrite = true
+    this
+  }
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for classes that provide [[Writer]].
+ */
+@Since("1.6.0")
+trait Writable {
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   */
+  @Since("1.6.0")
+  def write: Writer
+
+  /**
+   * Saves this ML instance to the input path, a shortcut for `write.save(path)`.
+   */
+  @Since("1.6.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  def save(path: String): Unit = write.save(path)
+}
+
+/**
+ * Abstract class for utility classes that can load ML instances.
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+abstract class Reader[T] extends BaseReadWrite {
+
+  /**
+   * Loads the ML component from the input path.
+   */
+  @Since("1.6.0")
+  def load(path: String): T
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
+}
+
+/**
+ * Trait for objects that provide [[Reader]].
+ * @tparam T ML instance type
+ */
+@Experimental
+@Since("1.6.0")
+trait Readable[T] {
+
+  /**
+   * Returns a [[Reader]] instance for this class.
+   */
+  @Since("1.6.0")
+  def read: Reader[T]
+
+  /**
+   * Reads an ML instance from the input path, a shortcut for `read.load(path)`.
+   */
+  @Since("1.6.0")
+  def load(path: String): T = read.load(path)
+}
+
+/**
+ * Default [[Writer]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @param instance object to save
+ */
+private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {
+
+  /**
+   * Saves the ML component to the input path.
+   */
+  override def save(path: String): Unit = {
+    val sc = sqlContext.sparkContext
+
+    val hadoopConf = sc.hadoopConfiguration
+    val fs = FileSystem.get(hadoopConf)
+    val p = new Path(path)
+    if (fs.exists(p)) {
+      if (shouldOverwrite) {
+        logInfo(s"Path $path already exists. It will be overwritten.")
+        // TODO: Revert back to the original content if save is not successful.
+        fs.delete(p, true)
+      } else {
+        throw new IOException(
+          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
+      }
+    }
+
+    val uid = instance.uid
+    val cls = instance.getClass.getName
+    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val jsonParams = params.map { case ParamPair(p, v) =>
+      p.name -> parse(p.jsonEncode(v))
+    }.toList
+    val metadata = ("class" -> cls) ~
+      ("timestamp" -> System.currentTimeMillis()) ~
+      ("uid" -> uid) ~
+      ("paramMap" -> jsonParams)
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataJson = compact(render(metadata))
+    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+  }
+}
+
+/**
+ * Default [[Reader]] implementation for transformers and estimators that contain basic
+ * (json4s-serializable) params and no data. This will not handle more complex params or types with
+ * data (e.g., models with coefficients).
+ * @tparam T ML instance type
+ */
+private[ml] class DefaultParamsReader[T] extends Reader[T] {
+
+  /**
+   * Loads the ML component from the input path.
+   */
+  override def load(path: String): T = {
+    implicit val format = DefaultFormats
+    val sc = sqlContext.sparkContext
+    val metadataPath = new Path(path, "metadata").toString
+    val metadataStr = sc.textFile(metadataPath, 1).first()
+    val metadata = parse(metadataStr)
+    val cls = Utils.classForName((metadata \ "class").extract[String])
+    val uid = (metadata \ "uid").extract[String]
+    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
+    (metadata \ "paramMap") match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val param = instance.getParam(paramName)
+          val value = param.jsonDecode(compact(render(jsonValue)))
+          instance.set(param, value)
+        }
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
+    }
+    instance.asInstanceOf[T]
+  }
+}
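Taken together, DefaultParamsWriter stores a single JSON document under `<path>/metadata`, and DefaultParamsReader rebuilds the instance reflectively through its public single-String-argument (uid) constructor and re-applies every saved param as a user-set value. A rough sketch of the metadata the writer would emit for a Binarizer with a 0.5 threshold, built with the same json4s DSL; the uid, timestamp, and param values are made up for illustration:

~~~scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

// Illustrative metadata for a saved Binarizer; the field names mirror DefaultParamsWriter above.
val metadata = ("class" -> "org.apache.spark.ml.feature.Binarizer") ~
  ("timestamp" -> 1446847863000L) ~
  ("uid" -> "binarizer_1a2b3c4d5e6f") ~
  ("paramMap" -> (("threshold" -> 0.5) ~
    ("inputCol" -> "feature") ~
    ("outputCol" -> "binarized_feature")))

// Prints (wrapped here for readability):
// {"class":"org.apache.spark.ml.feature.Binarizer","timestamp":1446847863000,
//  "uid":"binarizer_1a2b3c4d5e6f","paramMap":{"threshold":0.5,"inputCol":"feature",
//  "outputCol":"binarized_feature"}}
println(compact(render(metadata)))
~~~

Loading with an explicitly supplied SQL context rather than the active one follows the same pattern as writing, e.g. `Binarizer.read.context(sqlContext).load(path)`.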