diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-11-06 14:51:03 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-11-06 14:51:03 -0800 |
commit | c447c9d54603890db7399fb80adc9fae40b71f64 (patch) | |
tree | 0f8a339ee0b28a00944bea96879600315ab3ef17 /mllib/src/main/scala/org | |
parent | 3a652f691b220fada0286f8d0a562c5657973d4d (diff) | |
download | spark-c447c9d54603890db7399fb80adc9fae40b71f64.tar.gz spark-c447c9d54603890db7399fb80adc9fae40b71f64.tar.bz2 spark-c447c9d54603890db7399fb80adc9fae40b71f64.zip |
[SPARK-11217][ML] save/load for non-meta estimators and transformers
This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:
* class name
* uid
* timestamp
* paramMap
The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases.
~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")
Instance.load("path")
~~~
The param handling is different from the design doc. We didn't save default and user-set params separately, and when we load it back, all parameters are user-set. This does cause issues. But it also cause other issues if we modify the default params.
TODOs:
* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers
cc jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #9454 from mengxr/SPARK-11217.
Diffstat (limited to 'mllib/src/main/scala/org')
3 files changed, 230 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index edad754436..e5c25574d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with Writable with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,4 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object Binarizer extends Readable[Binarizer] { + + override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8361406f87..c932570918 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala new file mode 100644 index 0000000000..ea790e0ddd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +/** + * Trait for [[Writer]] and [[Reader]]. + */ +private[util] sealed trait BaseReadWrite { + private var optionSQLContext: Option[SQLContext] = None + + /** + * Sets the SQL context to use for saving/loading. + */ + @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { + SQLContext.getOrCreate(SparkContext.getOrCreate()) + } +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class Writer extends BaseReadWrite { + + protected var shouldOverwrite: Boolean = false + + /** + * Saves the ML instances to the input path. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit + + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for classes that provide [[Writer]]. + */ +@Since("1.6.0") +trait Writable { + + /** + * Returns a [[Writer]] instance for this ML instance. + */ + @Since("1.6.0") + def write: Writer + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) +} + +/** + * Abstract class for utility classes that can load ML instances. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class Reader[T] extends BaseReadWrite { + + /** + * Loads the ML component from the input path. + */ + @Since("1.6.0") + def load(path: String): T + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for objects that provide [[Reader]]. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait Readable[T] { + + /** + * Returns a [[Reader]] instance for this class. + */ + @Since("1.6.0") + def read: Reader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) +} + +/** + * Default [[Writer]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @param instance object to save + */ +private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { + + /** + * Saves the ML component to the input path. + */ + override def save(path: String): Unit = { + val sc = sqlContext.sparkContext + + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val jsonParams = params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } +} + +/** + * Default [[Reader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @tparam T ML instance type + */ +private[ml] class DefaultParamsReader[T] extends Reader[T] { + + /** + * Loads the ML component from the input path. + */ + override def load(path: String): T = { + implicit val format = DefaultFormats + val sc = sqlContext.sparkContext + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + val cls = Utils.classForName((metadata \ "class").extract[String]) + val uid = (metadata \ "uid").extract[String] + val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] + (metadata \ "paramMap") match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + } + instance.asInstanceOf[T] + } +} |