diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-03-24 15:29:17 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-24 15:29:17 -0700 |
commit | 2cf46d5a96897d5f97b364db357d30566183c6e7 (patch) | |
tree | 79a9657a638e27bc59eefae85b47a35c4015fc8c /mllib/src/main/scala | |
parent | d283223a5a75c53970e72a1016e0b237856b5ea1 (diff) | |
download | spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.gz spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.bz2 spark-2cf46d5a96897d5f97b364db357d30566183c6e7.zip |
[SPARK-11871] Add save/load for MLPC
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-11871
Add save/load for MLPC
## How was this patch tested?
Tested with Scala unit tests.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #9854 from yinxusen/SPARK-11871.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala | 69 |
1 file changed, 66 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 719d1076fe..f6de5f2df4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -19,12 +19,14 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame @@ -110,7 +112,7 @@ private object LabelConverter { class MultilayerPerceptronClassifier @Since("1.5.0") ( @Since("1.5.0") override val uid: String) extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] - with MultilayerPerceptronParams { + with MultilayerPerceptronParams with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) @@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( } } +@Since("2.0.0") +object MultilayerPerceptronClassifier + extends DefaultParamsReadable[MultilayerPerceptronClassifier] { + + @Since("2.0.0") + override def load(path: String): MultilayerPerceptronClassifier = super.load(path) +} + /** * :: Experimental :: * Classification model based on the Multilayer Perceptron. 
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") val layers: Array[Int], @Since("1.5.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] - with Serializable { + with Serializable with MLWritable { @Since("1.6.0") override val numFeatures: Int = layers.head @@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] ( override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) } + + @Since("2.0.0") + override def write: MLWriter = + new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this) +} + +@Since("2.0.0") +object MultilayerPerceptronClassificationModel + extends MLReadable[MultilayerPerceptronClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[MultilayerPerceptronClassificationModel] = + new MultilayerPerceptronClassificationModelReader + + @Since("2.0.0") + override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path) + + /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */ + private[MultilayerPerceptronClassificationModel] + class MultilayerPerceptronClassificationModelWriter( + instance: MultilayerPerceptronClassificationModel) extends MLWriter { + + private case class Data(layers: Array[Int], weights: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: layers, weights + val data = Data(instance.layers, instance.weights) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class MultilayerPerceptronClassificationModelReader + extends MLReader[MultilayerPerceptronClassificationModel] { + + /** Checked against metadata 
when loading model */ + private val className = classOf[MultilayerPerceptronClassificationModel].getName + + override def load(path: String): MultilayerPerceptronClassificationModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head() + val layers = data.getAs[Seq[Int]](0).toArray + val weights = data.getAs[Vector](1) + val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } |