diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-03-24 15:29:17 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-24 15:29:17 -0700 |
commit | 2cf46d5a96897d5f97b364db357d30566183c6e7 (patch) | |
tree | 79a9657a638e27bc59eefae85b47a35c4015fc8c /mllib/src/test | |
parent | d283223a5a75c53970e72a1016e0b237856b5ea1 (diff) | |
download | spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.gz spark-2cf46d5a96897d5f97b364db357d30566183c6e7.tar.bz2 spark-2cf46d5a96897d5f97b364db357d30566183c6e7.zip |
[SPARK-11871] Add save/load for MLPC
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-11871
Add save/load for MLPC
## How was this patch tested?
Test with Scala unit test
Author: Xusen Yin <yinxusen@gmail.com>
Closes #9854 from yinxusen/SPARK-11871.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 43 |
1 files changed, 37 insertions, 6 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 602b5a8116..5df8e6a847 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -18,31 +18,40 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} -class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class MultilayerPerceptronClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("XOR function learning as binary classification problem with two outputs.") { - val dataFrame = sqlContext.createDataFrame(Seq( + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = sqlContext.createDataFrame(Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), (Vectors.dense(1.0, 1.0), 0.0)) ).toDF("features", "label") + } + + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) .setSeed(11L) .setMaxIter(100) - val model = trainer.fit(dataFrame) - val result = model.transform(dataFrame) + val model = trainer.fit(dataset) + val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() predictionAndLabels.foreach { case Row(p: Double, l: Double) => assert(p == l) @@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) } + + test("read/write: MultilayerPerceptronClassifier") { + val mlp = new MultilayerPerceptronClassifier() + .setLayers(Array(2, 3, 2)) + .setMaxIter(5) + .setBlockSize(2) + .setSeed(42) + .setTol(0.1) + .setFeaturesCol("myFeatures") + .setLabelCol("myLabel") + .setPredictionCol("myPrediction") + + testDefaultReadWrite(mlp, testParams = true) + } + + test("read/write: MultilayerPerceptronClassificationModel") { + val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5) + val mlpModel = mlp.fit(dataset) + val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true) + assert(newMlpModel.layers === mlpModel.layers) + assert(newMlpModel.weights === mlpModel.weights) + } } |