aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-24 15:29:17 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-24 15:29:17 -0700
commit2cf46d5a96897d5f97b364db357d30566183c6e7 (patch)
tree79a9657a638e27bc59eefae85b47a35c4015fc8c /mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
parentd283223a5a75c53970e72a1016e0b237856b5ea1 (diff)
downloadspark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.gz
spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.bz2
spark-2cf46d5a96897d5f97b364db357d30566183c6e7.zip
[SPARK-11871] Add save/load for MLPC
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-11871 Add save/load for MLPC ## How was this patch tested? Test with Scala unit test Author: Xusen Yin <yinxusen@gmail.com> Closes #9854 from yinxusen/SPARK-11871.
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala43
1 files changed, 37 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 602b5a8116..5df8e6a847 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -18,31 +18,40 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
-class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class MultilayerPerceptronClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- test("XOR function learning as binary classification problem with two outputs.") {
- val dataFrame = sqlContext.createDataFrame(Seq(
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = sqlContext.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
(Vectors.dense(1.0, 1.0), 0.0))
).toDF("features", "label")
+ }
+
+ test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
.setSeed(11L)
.setMaxIter(100)
- val model = trainer.fit(dataFrame)
- val result = model.transform(dataFrame)
+ val model = trainer.fit(dataset)
+ val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l)
@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
}
+
+ test("read/write: MultilayerPerceptronClassifier") {
+ val mlp = new MultilayerPerceptronClassifier()
+ .setLayers(Array(2, 3, 2))
+ .setMaxIter(5)
+ .setBlockSize(2)
+ .setSeed(42)
+ .setTol(0.1)
+ .setFeaturesCol("myFeatures")
+ .setLabelCol("myLabel")
+ .setPredictionCol("myPrediction")
+
+ testDefaultReadWrite(mlp, testParams = true)
+ }
+
+ test("read/write: MultilayerPerceptronClassificationModel") {
+ val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5)
+ val mlpModel = mlp.fit(dataset)
+ val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true)
+ assert(newMlpModel.layers === mlpModel.layers)
+ assert(newMlpModel.weights === mlpModel.weights)
+ }
}