aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala9
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
index 1292e57d7c..dc91fc5f9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 1)
trainer.setWeights(initialWeights)
trainer.LBFGSOptimizer.setNumIterations(20)
@@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 2)
- trainer.SGDOptimizer.setNumIterations(2000)
- trainer.setWeights(initialWeights)
+ // TODO: add a test for SGD
+ trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20)
+ trainer.setWeights(initialWeights).setStackSize(1)
val model = trainer.train(rddData)
val predictionAndLabels = rddData.map { case (input, label) =>
(model.predict(input), label)