-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala      | 69
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 43
2 files changed, 103 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 719d1076fe..f6de5f2df4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
import scala.collection.JavaConverters._
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
@@ -110,7 +112,7 @@ private object LabelConverter {
class MultilayerPerceptronClassifier @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
- with MultilayerPerceptronParams {
+ with MultilayerPerceptronParams with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mlpc"))
@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
}
}
+@Since("2.0.0")
+object MultilayerPerceptronClassifier
+ extends DefaultParamsReadable[MultilayerPerceptronClassifier] {
+
+ @Since("2.0.0")
+ override def load(path: String): MultilayerPerceptronClassifier = super.load(path)
+}
+
/**
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0") val layers: Array[Int],
@Since("1.5.0") val weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
- with Serializable {
+ with Serializable with MLWritable {
@Since("1.6.0")
override val numFeatures: Int = layers.head
@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object MultilayerPerceptronClassificationModel
+ extends MLReadable[MultilayerPerceptronClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[MultilayerPerceptronClassificationModel] =
+ new MultilayerPerceptronClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
+ private[MultilayerPerceptronClassificationModel]
+ class MultilayerPerceptronClassificationModelWriter(
+ instance: MultilayerPerceptronClassificationModel) extends MLWriter {
+
+ private case class Data(layers: Array[Int], weights: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: layers, weights
+ val data = Data(instance.layers, instance.weights)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MultilayerPerceptronClassificationModelReader
+ extends MLReader[MultilayerPerceptronClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[MultilayerPerceptronClassificationModel].getName
+
+ override def load(path: String): MultilayerPerceptronClassificationModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
+ val layers = data.getAs[Seq[Int]](0).toArray
+ val weights = data.getAs[Vector](1)
+ val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
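
The persistence hooks added above follow the standard Spark ML MLWritable/MLReadable contract: the model exposes `write` and its companion object exposes `load`. A minimal round-trip sketch, assuming an already-fitted model `mlpModel` and an illustrative path (neither is part of this patch):

  // Save a trained model, then load it back via the companion object's reader.
  // `mlpModel` is a fitted MultilayerPerceptronClassificationModel; the path is illustrative.
  mlpModel.write.overwrite().save("/tmp/mlp-model")
  val restored = MultilayerPerceptronClassificationModel.load("/tmp/mlp-model")
  // Layers and weights are written to the "data" parquet and survive the round trip.
  assert(restored.layers.sameElements(mlpModel.layers))
  assert(restored.weights == mlpModel.weights)
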
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 602b5a8116..5df8e6a847 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -18,31 +18,40 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
-class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class MultilayerPerceptronClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- test("XOR function learning as binary classification problem with two outputs.") {
- val dataFrame = sqlContext.createDataFrame(Seq(
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = sqlContext.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
(Vectors.dense(1.0, 1.0), 0.0))
).toDF("features", "label")
+ }
+
+ test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
.setSeed(11L)
.setMaxIter(100)
- val model = trainer.fit(dataFrame)
- val result = model.transform(dataFrame)
+ val model = trainer.fit(dataset)
+ val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l)
@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
}
+
+ test("read/write: MultilayerPerceptronClassifier") {
+ val mlp = new MultilayerPerceptronClassifier()
+ .setLayers(Array(2, 3, 2))
+ .setMaxIter(5)
+ .setBlockSize(2)
+ .setSeed(42)
+ .setTol(0.1)
+ .setFeaturesCol("myFeatures")
+ .setLabelCol("myLabel")
+ .setPredictionCol("myPrediction")
+
+ testDefaultReadWrite(mlp, testParams = true)
+ }
+
+ test("read/write: MultilayerPerceptronClassificationModel") {
+ val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5)
+ val mlpModel = mlp.fit(dataset)
+ val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true)
+ assert(newMlpModel.layers === mlpModel.layers)
+ assert(newMlpModel.weights === mlpModel.weights)
+ }
}
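
With DefaultParamsWritable/DefaultParamsReadable mixed in, the un-fitted estimator becomes persistable too; only its Params are written, with no model data. A short sketch under the same assumptions (the path is illustrative):

  // Round-trip the estimator: saves metadata/Params only.
  val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5)
  mlp.write.overwrite().save("/tmp/mlp-estimator")
  val sameMlp = MultilayerPerceptronClassifier.load("/tmp/mlp-estimator")
  assert(sameMlp.getMaxIter == 5)
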