diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 82fc80c580..5f60dea91f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} @@ -182,6 +184,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. */ |