aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala9
1 files changed, 9 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 82fc80c580..5f60dea91f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
@@ -182,6 +184,13 @@ class MultilayerPerceptronClassificationModel private[ml] (
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
/**
+ * Returns layers in a Java List.
+ */
+ private[ml] def javaLayers: java.util.List[Int] = {
+ layers.toList.asJava
+ }
+
+ /**
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
*/