aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-11 08:52:28 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-11 08:52:28 -0700
commitb01b26260625f0ba14e5f3010207666d62d93864 (patch)
treeb3e891231cf80f750ff4016dd738f7c3266b8288 /mllib
parentb656e6134fc5cd27e1fe6b6ab30fd7633cab0b14 (diff)
downloadspark-b01b26260625f0ba14e5f3010207666d62d93864.tar.gz
spark-b01b26260625f0ba14e5f3010207666d62d93864.tar.bz2
spark-b01b26260625f0ba14e5f3010207666d62d93864.zip
[SPARK-9773] [ML] [PySpark] Add Python API for MultilayerPerceptronClassifier
Add Python API for ```MultilayerPerceptronClassifier```. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8067 from yanboliang/SPARK-9773.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala9
1 files changed, 9 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 82fc80c580..5f60dea91f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import scala.collection.JavaConverters._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
@@ -182,6 +184,13 @@ class MultilayerPerceptronClassificationModel private[ml] (
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
/**
+ * Returns layers in a Java List.
+ */
+ private[ml] def javaLayers: java.util.List[Int] = {
+ layers.toList.asJava
+ }
+
+ /**
* Predict label for the given features.
* This internal method is used to implement [[transform()]] and output [[predictionCol]].
*/