diff options
author | Xin Ren <iamshrek@126.com> | 2016-08-24 11:18:10 -0700 |
---|---|---|
committer | Felix Cheung <felixcheung@apache.org> | 2016-08-24 11:18:10 -0700 |
commit | 2fbdb606392631b1dff88ec86f388cc2559c28f5 (patch) | |
tree | 002050c92864378c0c65a5d6c449420c8d604170 /mllib | |
parent | d2932a0e987132c694ed59515b7c77adaad052e6 (diff) | |
download | spark-2fbdb606392631b1dff88ec86f388cc2559c28f5.tar.gz spark-2fbdb606392631b1dff88ec86f388cc2559c28f5.tar.bz2 spark-2fbdb606392631b1dff88ec86f388cc2559c28f5.zip |
[SPARK-16445][MLLIB][SPARKR] Multilayer Perceptron Classifier wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-16445
## What changes were proposed in this pull request?
Create Multilayer Perceptron Classifier wrapper in SparkR
## How was this patch tested?
Tested manually on local machine
Author: Xin Ren <iamshrek@126.com>
Closes #14447 from keypointt/SPARK-16445.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala | 134 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala | 2 |
2 files changed, 136 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
new file mode 100644
index 0000000000..be51e74187
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
+import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class MultilayerPerceptronClassifierWrapper private (
+    val pipeline: PipelineModel,
+    val labelCount: Long,
+    val layers: Array[Int],
+    val weights: Array[Double]
+  ) extends MLWritable {
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset)
+  }
+
+  /**
+   * Returns an [[MLWriter]] instance for this ML instance.
+   */
+  override def write: MLWriter =
+    new MultilayerPerceptronClassifierWrapper.MultilayerPerceptronClassifierWrapperWriter(this)
+}
+
+private[r] object MultilayerPerceptronClassifierWrapper
+  extends MLReadable[MultilayerPerceptronClassifierWrapper] {
+
+  val PREDICTED_LABEL_COL = "prediction"
+
+  def fit(
+      data: DataFrame,
+      blockSize: Int,
+      layers: Array[Double],
+      solver: String,
+      maxIter: Int,
+      tol: Double,
+      stepSize: Double,
+      seed: Int
+  ): MultilayerPerceptronClassifierWrapper = {
+    // get labels and feature names from output schema
+    val schema = data.schema
+
+    // assemble and fit the pipeline
+    val mlp = new MultilayerPerceptronClassifier()
+      .setLayers(layers.map(_.toInt))
+      .setBlockSize(blockSize)
+      .setSolver(solver)
+      .setMaxIter(maxIter)
+      .setTol(tol)
+      .setStepSize(stepSize)
+      .setSeed(seed)
+      .setPredictionCol(PREDICTED_LABEL_COL)
+    val pipeline = new Pipeline()
+      .setStages(Array(mlp))
+      .fit(data)
+
+    val multilayerPerceptronClassificationModel: MultilayerPerceptronClassificationModel =
+      pipeline.stages.head.asInstanceOf[MultilayerPerceptronClassificationModel]
+
+    val weights = multilayerPerceptronClassificationModel.weights.toArray
+    val layersFromPipeline = multilayerPerceptronClassificationModel.layers
+    val labelCount = data.select("label").distinct().count()
+
+    new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layersFromPipeline, weights)
+  }
+
+  /**
+   * Returns an [[MLReader]] instance for this class.
+   */
+  override def read: MLReader[MultilayerPerceptronClassifierWrapper] =
+    new MultilayerPerceptronClassifierWrapperReader
+
+  override def load(path: String): MultilayerPerceptronClassifierWrapper = super.load(path)
+
+  class MultilayerPerceptronClassifierWrapperReader
+    extends MLReader[MultilayerPerceptronClassifierWrapper]{
+
+    override def load(path: String): MultilayerPerceptronClassifierWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val labelCount = (rMetadata \ "labelCount").extract[Long]
+      val layers = (rMetadata \ "layers").extract[Array[Int]]
+      val weights = (rMetadata \ "weights").extract[Array[Double]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layers, weights)
+    }
+  }
+
+  class MultilayerPerceptronClassifierWrapperWriter(instance: MultilayerPerceptronClassifierWrapper)
+    extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("labelCount" -> instance.labelCount) ~
+        ("layers" -> instance.layers.toSeq) ~
+        ("weights" -> instance.weights.toArray.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 51a65f7fc4..d64de1b6ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
         GeneralizedLinearRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.KMeansWrapper" =>
         KMeansWrapper.load(path)
+      case "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper" =>
+        MultilayerPerceptronClassifierWrapper.load(path)
       case "org.apache.spark.ml.r.LDAWrapper" =>
         LDAWrapper.load(path)
       case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>