diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-10-25 21:42:59 -0700 |
---|---|---|
committer | Felix Cheung <felixcheung@apache.org> | 2016-10-25 21:42:59 -0700 |
commit | 12b3e8d2e02788c3bebfecdd69755e94d80011c9 (patch) | |
tree | a8577ebadef6f612401fb7bd92d22d23f4a30ced /mllib | |
parent | c329a568b58d65c492a43926bf0f588f2ae6a66e (diff) | |
download | spark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.tar.gz spark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.tar.bz2 spark-12b3e8d2e02788c3bebfecdd69755e94d80011c9.zip |
[SPARK-18007][SPARKR][ML] update SparkR MLP - add initalWeights parameter
## What changes were proposed in this pull request?
update SparkR MLP, add initalWeights parameter.
## How was this patch tested?
test added.
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #15552 from WeichenXu123/mlp_r_add_initialWeight_param.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 1067300353..2193eb80e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -24,6 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter} import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,8 @@ private[r] object MultilayerPerceptronClassifierWrapper maxIter: Int, tol: Double, stepSize: Double, - seed: String + seed: String, + initialWeights: Array[Double] ): MultilayerPerceptronClassifierWrapper = { // get labels and feature names from output schema val schema = data.schema @@ -73,6 +75,11 @@ private[r] object MultilayerPerceptronClassifierWrapper .setStepSize(stepSize) .setPredictionCol(PREDICTED_LABEL_COL) if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt) + if (initialWeights != null) { + require(initialWeights.length > 0) + mlp.setInitialWeights(Vectors.dense(initialWeights)) + } + val pipeline = new Pipeline() .setStages(Array(mlp)) .fit(data) |