author    Yanbo Liang <ybliang8@gmail.com>    2016-04-25 14:08:41 -0700
committer Xiangrui Meng <meng@databricks.com>    2016-04-25 14:08:41 -0700
commit    9cb3ba1013a7eae11be8a00fa4a9c5308bb20195 (patch)
tree      eb275db612f3bc4f438aa426bb49c528d6fc0fe9 /mllib
parent    0c47e274ab8c286498fa002e2c92febcb53905c6 (diff)
[SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR
## What changes were proposed in this pull request?

SparkR `NaiveBayesModel` now supports `save`/`load` through the following API:

```
df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)
ml.save(model, path)
model2 <- ml.load(path)
```

## How was this patch tested?

Added unit tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12573 from yanboliang/spark-14312.
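On disk the saved model has two parts, written by the wrapper code in this patch: an `rMetadata` JSON file recording the class name, labels, and features, plus a standard `PipelineModel` directory. For orientation, a minimal JVM-side sketch of the same roundtrip (assumptions: the `fit(formula, data, laplace)` parameter list inferred from the wrapper below, some training DataFrame `df`, and an illustrative path):

```
// Sketch only: mirrors the SparkR ml.save/ml.load calls from the JVM side.
// fit's (formula, data, laplace) parameter list is an assumption from context.
val wrapper = org.apache.spark.ml.r.NaiveBayesWrapper.fit("education ~ .", df, 0.0)
wrapper.save("/tmp/nbModel")   // MLWritable.save: writes rMetadata/ and pipeline/
val restored = org.apache.spark.ml.r.NaiveBayesWrapper.load("/tmp/nbModel")
```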
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 52
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala         | 45
2 files changed, 94 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index b17207e99b..27c7e72881 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -17,16 +17,23 @@
package org.apache.spark.ml.r
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class NaiveBayesWrapper private (
- pipeline: PipelineModel,
+ val pipeline: PipelineModel,
val labels: Array[String],
- val features: Array[String]) {
+ val features: Array[String]) extends MLWritable {
import NaiveBayesWrapper._
@@ -41,9 +48,11 @@ private[r] class NaiveBayesWrapper private (
.drop(PREDICTED_LABEL_INDEX_COL)
.drop(naiveBayesModel.getFeaturesCol)
}
+
+ override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
}
-private[r] object NaiveBayesWrapper {
+private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"
@@ -74,4 +83,41 @@ private[r] object NaiveBayesWrapper {
.fit(data)
new NaiveBayesWrapper(pipeline, labels, features)
}
+
+ override def read: MLReader[NaiveBayesWrapper] = new NaiveBayesWrapperReader
+
+ override def load(path: String): NaiveBayesWrapper = super.load(path)
+
+ class NaiveBayesWrapperWriter(instance: NaiveBayesWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("labels" -> instance.labels.toSeq) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] {
+
+ override def load(path: String): NaiveBayesWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val labels = (rMetadata \ "labels").extract[Array[String]]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new NaiveBayesWrapper(pipeline, labels, features)
+ }
+ }
}
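For concreteness, the `rMetadata` file written by `saveImpl` above is a single line of JSON. A self-contained sketch of the shape it takes, with illustrative labels and features (the real values come from the fitted pipeline):

```
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

// Rebuilds the JSON shape NaiveBayesWrapperWriter persists; values illustrative.
val rMetadata = ("class" -> "org.apache.spark.ml.r.NaiveBayesWrapper") ~
  ("labels" -> Seq("0", "1")) ~
  ("features" -> Seq("age", "parity"))
println(compact(render(rMetadata)))
// {"class":"org.apache.spark.ml.r.NaiveBayesWrapper","labels":["0","1"],"features":["age","parity"]}
```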
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
new file mode 100644
index 0000000000..7f6f147532
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.util.MLReader
+
+/**
+ * This is the Scala stub of SparkR ml.load. It dispatches the call to the corresponding
+ * model wrapper's loading function, based on the class name extracted from the rMetadata
+ * stored at the given path.
+ */
+private[r] object RWrappers extends MLReader[Object] {
+
+ override def load(path: String): Object = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val className = (rMetadata \ "class").extract[String]
+ className match {
+ case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+ case _ =>
+ throw new SparkException(s"SparkR ml.load does not support loading $className")
+ }
+ }
+}
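Since `RWrappers` is an `MLReader[Object]`, callers get back an untyped reference and recover the concrete wrapper via the class name recorded in `rMetadata`. A brief usage sketch (the path is illustrative):

```
// Sketch: generic load, then downcast to the wrapper type recorded in rMetadata.
org.apache.spark.ml.r.RWrappers.load("/tmp/nbModel") match {
  case nb: org.apache.spark.ml.r.NaiveBayesWrapper => println(nb.labels.mkString(", "))
  case other => throw new IllegalArgumentException(s"Unexpected wrapper: $other")
}
```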