aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorAlexander Ulanov <nashb@yandex.ru>2015-07-31 11:22:40 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-31 11:23:30 -0700
commit6add4eddb39e7748a87da3e921ea3c7881d30a82 (patch)
tree92fecca0d3008e5e537a78dcf87349b89a32a8dc /mllib/src/test
parent0024da9157ba12ec84883a78441fa6835c1d0042 (diff)
downloadspark-6add4eddb39e7748a87da3e921ea3c7881d30a82.tar.gz
spark-6add4eddb39e7748a87da3e921ea3c7881d30a82.tar.bz2
spark-6add4eddb39e7748a87da3e921ea3c7881d30a82.zip
[SPARK-9471] [ML] Multilayer Perceptron
This pull request contains the following feature for ML: - Multilayer Perceptron classifier This implementation is based on our initial pull request with bgreeven: https://github.com/apache/spark/pull/1290 and inspired by very insightful suggestions from mengxr and witgo (I would like to thank all other people from the mentioned thread for useful discussions). The original code was extensively tested and benchmarked. Since then, I've addressed two main requirements that prevented the code from merging into the main branch: - Extensible interface, so it will be easy to implement new types of networks - Main building blocks are traits `Layer` and `LayerModel`. They are used for constructing layers of ANN. New layers can be added by extending the `Layer` and `LayerModel` traits. These traits are private in this release in order to save path to improve them based on community feedback - Back propagation is implemented in general form, so there is no need to change it (optimization algorithm) when new layers are implemented - Speed and scalability: this implementation has to be comparable in terms of speed to the state of the art single node implementations. - The developed benchmark for large ANN shows that the proposed code is on par with C++ CPU implementation and scales nicely with the number of workers. Details can be found here: https://github.com/avulanov/ann-benchmark - DBN and RBM by witgo https://github.com/witgo/spark/tree/ann-interface-gemm-dbn - Dropout https://github.com/avulanov/spark/tree/ann-interface-gemm mengxr and dbtsai kindly agreed to perform code review. Author: Alexander Ulanov <nashb@yandex.ru> Author: Bert Greevenbosch <opensrc@bertgreevenbosch.nl> Closes #7621 from avulanov/SPARK-2352-ann and squashes the following commits: 4806b6f [Alexander Ulanov] Addressing reviewers comments. a7e7951 [Alexander Ulanov] Default blockSize: 100. Added documentation to blockSize parameter and DataStacker class f69bb3d [Alexander Ulanov] Addressing reviewers comments. 374bea6 [Alexander Ulanov] Moving ANN to ML package. GradientDescent constructor is now spark private. 43b0ae2 [Alexander Ulanov] Addressing reviewers comments. Adding multiclass test. 9d18469 [Alexander Ulanov] Addressing reviewers comments: unnecessary copy of data in predict 35125ab [Alexander Ulanov] Style fix in tests e191301 [Alexander Ulanov] Apache header a226133 [Alexander Ulanov] Multilayer Perceptron regressor and classifier
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala91
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala91
2 files changed, 182 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
new file mode 100644
index 0000000000..1292e57d7c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+
+class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ // TODO: test for weights comparison with Weka MLP
+ test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
+ val inputs = Array(
+ Array(0.0, 0.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0),
+ Array(1.0, 1.0)
+ )
+ val outputs = Array(0.0, 1.0, 1.0, 0.0)
+ val data = inputs.zip(outputs).map { case (features, label) =>
+ (Vectors.dense(features), Vectors.dense(label))
+ }
+ val rddData = sc.parallelize(data, 1)
+ val hiddenLayersTopology = Array(5)
+ val dataSample = rddData.first()
+ val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
+ val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
+ val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val trainer = new FeedForwardTrainer(topology, 2, 1)
+ trainer.setWeights(initialWeights)
+ trainer.LBFGSOptimizer.setNumIterations(20)
+ val model = trainer.train(rddData)
+ val predictionAndLabels = rddData.map { case (input, label) =>
+ (model.predict(input)(0), label(0))
+ }.collect()
+ predictionAndLabels.foreach { case (p, l) =>
+ assert(math.round(p) === l)
+ }
+ }
+
+ test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
+ val inputs = Array(
+ Array(0.0, 0.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0),
+ Array(1.0, 1.0)
+ )
+ val outputs = Array(
+ Array(1.0, 0.0),
+ Array(0.0, 1.0),
+ Array(0.0, 1.0),
+ Array(1.0, 0.0)
+ )
+ val data = inputs.zip(outputs).map { case (features, label) =>
+ (Vectors.dense(features), Vectors.dense(label))
+ }
+ val rddData = sc.parallelize(data, 1)
+ val hiddenLayersTopology = Array(5)
+ val dataSample = rddData.first()
+ val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
+ val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
+ val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val trainer = new FeedForwardTrainer(topology, 2, 2)
+ trainer.SGDOptimizer.setNumIterations(2000)
+ trainer.setWeights(initialWeights)
+ val model = trainer.train(rddData)
+ val predictionAndLabels = rddData.map { case (input, label) =>
+ (model.predict(input), label)
+ }.collect()
+ predictionAndLabels.foreach { case (p, l) =>
+ assert(p ~== l absTol 0.5)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
new file mode 100644
index 0000000000..ddc948f65d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.classification.LogisticRegressionSuite._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.Row
+
+class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("XOR function learning as binary classification problem with two outputs.") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(11L)
+ .setMaxIter(100)
+ val model = trainer.fit(dataFrame)
+ val result = model.transform(dataFrame)
+ val predictionAndLabels = result.select("prediction", "label").collect()
+ predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
+ assert(p == l)
+ }
+ }
+
+ // TODO: implement a more rigorous test
+ test("3 class classification with 2 hidden layers") {
+ val nPoints = 1000
+
+ // The following weights are taken from OneVsRestSuite.scala
+ // they represent 3-class iris dataset
+ val weights = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+ val rdd = sc.parallelize(generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42), 2)
+ val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features")
+ val numClasses = 3
+ val numIterations = 100
+ val layers = Array[Int](4, 5, 4, numClasses)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(11L)
+ .setMaxIter(numIterations)
+ val model = trainer.fit(dataFrame)
+ val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
+ .map { case Row(p: Double, l: Double) => (p, l) }
+ // train multinomial logistic regression
+ val lr = new LogisticRegressionWithLBFGS()
+ .setIntercept(true)
+ .setNumClasses(numClasses)
+ lr.optimizer.setRegParam(0.0)
+ .setNumIterations(numIterations)
+ val lrModel = lr.run(rdd)
+ val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label))
+ // MLP's predictions should not differ a lot from LR's.
+ val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels)
+ val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
+ assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
+ }
+}