aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorWeichenXu <WeichenXu123@outlook.com>2016-10-25 21:42:59 -0700
committerFelix Cheung <felixcheung@apache.org>2016-10-25 21:42:59 -0700
commit12b3e8d2e02788c3bebfecdd69755e94d80011c9 (patch)
treea8577ebadef6f612401fb7bd92d22d23f4a30ced /mllib
parentc329a568b58d65c492a43926bf0f588f2ae6a66e (diff)
downloadspark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.tar.gz
spark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.tar.bz2
spark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.zip
[SPARK-18007][SPARKR][ML] update SparkR MLP - add initalWeights parameter
## What changes were proposed in this pull request? update SparkR MLP, add initalWeights parameter. ## How was this patch tested? test added. Author: WeichenXu <WeichenXu123@outlook.com> Closes #15552 from WeichenXu123/mlp_r_add_initialWeight_param.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala9
1 files changed, 8 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
index 1067300353..2193eb80e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -24,6 +24,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -58,7 +59,8 @@ private[r] object MultilayerPerceptronClassifierWrapper
maxIter: Int,
tol: Double,
stepSize: Double,
- seed: String
+ seed: String,
+ initialWeights: Array[Double]
): MultilayerPerceptronClassifierWrapper = {
// get labels and feature names from output schema
val schema = data.schema
@@ -73,6 +75,11 @@ private[r] object MultilayerPerceptronClassifierWrapper
.setStepSize(stepSize)
.setPredictionCol(PREDICTED_LABEL_COL)
if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt)
+ if (initialWeights != null) {
+ require(initialWeights.length > 0)
+ mlp.setInitialWeights(Vectors.dense(initialWeights))
+ }
+
val pipeline = new Pipeline()
.setStages(Array(mlp))
.fit(data)