aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-24 15:29:17 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-24 15:29:17 -0700
commit2cf46d5a96897d5f97b364db357d30566183c6e7 (patch)
tree79a9657a638e27bc59eefae85b47a35c4015fc8c /mllib/src/main
parentd283223a5a75c53970e72a1016e0b237856b5ea1 (diff)
downloadspark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.gz
spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.bz2
spark-2cf46d5a96897d5f97b364db357d30566183c6e7.zip
[SPARK-11871] Add save/load for MLPC
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-11871 Add save/load for MLPC ## How was this patch tested? Test with Scala unit test Author: Xusen Yin <yinxusen@gmail.com> Closes #9854 from yinxusen/SPARK-11871.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala69
1 files changed, 66 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 719d1076fe..f6de5f2df4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
import scala.collection.JavaConverters._
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
@@ -110,7 +112,7 @@ private object LabelConverter {
class MultilayerPerceptronClassifier @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
- with MultilayerPerceptronParams {
+ with MultilayerPerceptronParams with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mlpc"))
@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
}
}
+@Since("2.0.0")
+object MultilayerPerceptronClassifier
+ extends DefaultParamsReadable[MultilayerPerceptronClassifier] {
+
+ @Since("2.0.0")
+ override def load(path: String): MultilayerPerceptronClassifier = super.load(path)
+}
+
/**
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0") val layers: Array[Int],
@Since("1.5.0") val weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
- with Serializable {
+ with Serializable with MLWritable {
@Since("1.6.0")
override val numFeatures: Int = layers.head
@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object MultilayerPerceptronClassificationModel
+ extends MLReadable[MultilayerPerceptronClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[MultilayerPerceptronClassificationModel] =
+ new MultilayerPerceptronClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
+ private[MultilayerPerceptronClassificationModel]
+ class MultilayerPerceptronClassificationModelWriter(
+ instance: MultilayerPerceptronClassificationModel) extends MLWriter {
+
+ private case class Data(layers: Array[Int], weights: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: layers, weights
+ val data = Data(instance.layers, instance.weights)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MultilayerPerceptronClassificationModelReader
+ extends MLReader[MultilayerPerceptronClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[MultilayerPerceptronClassificationModel].getName
+
+ override def load(path: String): MultilayerPerceptronClassificationModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
+ val layers = data.getAs[Seq[Int]](0).toArray
+ val weights = data.getAs[Vector](1)
+ val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}