author    Yanbo Liang <ybliang8@gmail.com>      2016-04-25 14:08:41 -0700
committer Xiangrui Meng <meng@databricks.com>   2016-04-25 14:08:41 -0700
commit    9cb3ba1013a7eae11be8a00fa4a9c5308bb20195 (patch)
tree      eb275db612f3bc4f438aa426bb49c528d6fc0fe9
parent    0c47e274ab8c286498fa002e2c92febcb53905c6 (diff)
[SPARK-14312][ML][SPARKR] NaiveBayes model persistence in SparkR
## What changes were proposed in this pull request?

SparkR ```NaiveBayesModel``` now supports ```save/load``` via the following API:

```
df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)
ml.save(model, path)
model2 <- ml.load(path)
```

## How was this patch tested?

Added unit tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12573 from yanboliang/spark-14312.
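For reference, the overwrite flag added in this patch is exercised by the new unit tests; a fuller round trip looks roughly like the sketch below (mirroring the test code; it assumes an active SparkR session with ```sqlContext``` available):

```
df <- createDataFrame(sqlContext, infert)
model <- naiveBayes(education ~ ., df, laplace = 0)

modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
ml.save(model, modelPath)
# A second save to the same path errors unless overwrite = TRUE.
ml.save(model, modelPath, overwrite = TRUE)

# Reloading yields an equivalent NaiveBayesModel.
model2 <- ml.load(modelPath)
s <- summary(model)
s2 <- summary(model2)
stopifnot(isTRUE(all.equal(s$apriori, s2$apriori)))

unlink(modelPath)
```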
-rw-r--r--  R/pkg/NAMESPACE                                                      6
-rw-r--r--  R/pkg/R/generics.R                                                   4
-rw-r--r--  R/pkg/R/mllib.R                                                     48
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R                              12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala  52
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala          45
6 files changed, 162 insertions(+), 5 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 0f92b5e597..c0a63d6b3e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -107,7 +107,8 @@ exportMethods("arrange",
"write.jdbc",
"write.json",
"write.parquet",
- "write.text")
+ "write.text",
+ "ml.save")
exportClasses("Column")
@@ -299,7 +300,8 @@ export("as.DataFrame",
"tableNames",
"tables",
"uncacheTable",
- "print.summary.GeneralizedLinearRegressionModel")
+ "print.summary.GeneralizedLinearRegressionModel",
+ "ml.load")
export("structField",
"structField.jobj",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 04274a12bc..f654d8330c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1200,3 +1200,7 @@ setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBa
#' @rdname survreg
#' @export
setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+
+#' @rdname ml.save
+#' @export
+setGeneric("ml.save", function(object, path, ...) { standardGeneric("ml.save") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7dd82963a1..cda6100e79 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
return(new("NaiveBayesModel", jobj = jobj))
})
+#' Save the Bernoulli naive Bayes model to the input path.
+#'
+#' @param object A fitted Bernoulli naive Bayes model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. The default is FALSE,
+#' which means an exception is thrown if the output path already exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path Path of the model to read.
+#' @return a fitted MLlib model
+#' @rdname ml.load
+#' @name ml.load
+#' @export
+#' @examples
+#' \dontrun{
+#' path <- "path/to/model"
+#' model <- ml.load(path)
+#' }
+ml.load <- function(path) {
+ path <- suppressWarnings(normalizePath(path))
+ jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
+ if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
+ return(new("NaiveBayesModel", jobj = jobj))
+ } else {
+ stop(paste("Unsupported model: ", jobj))
+ }
+}
+
#' Fit an accelerated failure time (AFT) survival regression model.
#'
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1597306bb6..63ec84e497 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -204,6 +204,18 @@ test_that("naiveBayes", {
"Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No",
"Yes", "Yes", "No", "No"))
+ # Test model save/load
+ modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
+ ml.save(m, modelPath)
+ expect_error(ml.save(m, modelPath))
+ ml.save(m, modelPath, overwrite = TRUE)
+ m2 <- ml.load(modelPath)
+ s2 <- summary(m2)
+ expect_equal(s$apriori, s2$apriori)
+ expect_equal(s$tables, s2$tables)
+
+ unlink(modelPath)
+
# Test e1071::naiveBayes
if (requireNamespace("e1071", quietly = TRUE)) {
expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index b17207e99b..27c7e72881 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -17,16 +17,23 @@
package org.apache.spark.ml.r
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class NaiveBayesWrapper private (
- pipeline: PipelineModel,
+ val pipeline: PipelineModel,
val labels: Array[String],
- val features: Array[String]) {
+ val features: Array[String]) extends MLWritable {
import NaiveBayesWrapper._
@@ -41,9 +48,11 @@ private[r] class NaiveBayesWrapper private (
.drop(PREDICTED_LABEL_INDEX_COL)
.drop(naiveBayesModel.getFeaturesCol)
}
+
+ override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
}
-private[r] object NaiveBayesWrapper {
+private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
val PREDICTED_LABEL_COL = "prediction"
@@ -74,4 +83,41 @@ private[r] object NaiveBayesWrapper {
.fit(data)
new NaiveBayesWrapper(pipeline, labels, features)
}
+
+ override def read: MLReader[NaiveBayesWrapper] = new NaiveBayesWrapperReader
+
+ override def load(path: String): NaiveBayesWrapper = super.load(path)
+
+ class NaiveBayesWrapperWriter(instance: NaiveBayesWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("labels" -> instance.labels.toSeq) ~
+ ("features" -> instance.features.toSeq)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] {
+
+ override def load(path: String): NaiveBayesWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val labels = (rMetadata \ "labels").extract[Array[String]]
+ val features = (rMetadata \ "features").extract[Array[String]]
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new NaiveBayesWrapper(pipeline, labels, features)
+ }
+ }
}
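For orientation, the writer above produces a two-part layout on disk: the R-level metadata as a single-line JSON document (written with ```saveAsTextFile```, so it lands in a part file) plus the fitted ```PipelineModel``` persisted by MLlib's own writer. A rough sketch of what to expect from R (illustrative only; exact part-file names come from Hadoop's text output format):

```
# After ml.save(model, modelPath), the directory contains roughly:
#   modelPath/rMetadata/part-00000   one JSON line with "class", "labels", "features"
#   modelPath/pipeline/              the PipelineModel saved by MLlib
list.files(modelPath, recursive = TRUE)
```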
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
new file mode 100644
index 0000000000..7f6f147532
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.util.MLReader
+
+/**
+ * This is the Scala stub of SparkR's ml.load. It dispatches the call to the corresponding
+ * model wrapper's loading function according to the class name extracted from the rMetadata
+ * file under the given path.
+ */
+private[r] object RWrappers extends MLReader[Object] {
+
+ override def load(path: String): Object = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val className = (rMetadata \ "class").extract[String]
+ className match {
+ case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+ case _ =>
+ throw new SparkException(s"SparkR ml.load does not support loading $className")
+ }
+ }
+}
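NaiveBayesWrapper is the only case handled so far, so pointing ml.load at any other saved wrapper surfaces the SparkException above; future wrappers plug in by adding a case to this match. From the R side the failure looks roughly like the sketch below (the path is a placeholder):

```
msg <- tryCatch(ml.load("path/to/some/other/model"),
                error = function(e) conditionMessage(e))
# msg should mention "SparkR ml.load does not support loading <className>"
```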