 R/pkg/NAMESPACE                                              |   3
 R/pkg/R/generics.R                                           |   4
 R/pkg/R/mllib.R                                              | 159
 R/pkg/inst/tests/testthat/test_mllib.R                       |  40
 mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala  | 119
 mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala   |   2
 6 files changed, 322 insertions(+), 5 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4404cffc29..e1b87b28d3 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -29,7 +29,8 @@ exportMethods("glm",
"spark.posterior",
"spark.perplexity",
"spark.isoreg",
- "spark.gaussianMixture")
+ "spark.gaussianMixture",
+ "spark.als")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index fe04bcfc7d..693aa31d3e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1332,3 +1332,7 @@ setGeneric("spark.gaussianMixture",
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
+
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b9527410a9..36f38fc73a 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -74,6 +74,13 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' @note GaussianMixtureModel since 2.1.0
setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+#' S4 class that represents an ALSModel
+#'
+#' @param jobj a Java object reference to the backing Scala ALSWrapper
+#' @export
+#' @note ALSModel since 2.1.0
+setClass("ALSModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -82,8 +89,8 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
+#' @seealso \link{spark.survreg}, \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -95,10 +102,11 @@ NULL
#' @name predict
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
NULL
+
#' Generalized Linear Models
#'
#' Fits generalized linear model against a Spark DataFrame.
@@ -801,6 +809,8 @@ read.ml <- function(path) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
return(new("GaussianMixtureModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
+ return(new("ALSModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -1053,4 +1063,145 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
- })
\ No newline at end of file
+ })
+
+#' Alternating Least Squares (ALS) for Collaborative Filtering
+#'
+#' \code{spark.als} learns latent factors in collaborative filtering via alternating least
+#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
+#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
+#' Collaborative Filtering}.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param ratingCol column name for ratings.
+#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
+#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
+#' @param rank rank of the matrix factorization (> 0).
+#' @param reg regularization parameter (>= 0).
+#' @param maxIter maximum number of iterations (> 0).
+#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
+#' @param implicitPrefs logical value indicating whether to use implicit preference.
+#' @param alpha alpha parameter in the implicit preference formulation (>= 0).
+#' @param seed integer seed for random number generation.
+#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
+#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
+#' @param checkpointInterval checkpoint interval in iterations (>= 1); -1 disables checkpointing.
+#'
+#' @return \code{spark.als} returns a fitted ALS model.
+#' @rdname spark.als
+#' @aliases spark.als,SparkDataFrame-method
+#' @name spark.als
+#' @export
+#' @examples
+#' \dontrun{
+#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+#' list(2, 1, 1.0), list(2, 2, 5.0))
+#' df <- createDataFrame(ratings, c("user", "item", "rating"))
+#' model <- spark.als(df, "rating", "user", "item")
+#'
+#' # extract latent factors
+#' stats <- summary(model)
+#' userFactors <- stats$userFactors
+#' itemFactors <- stats$itemFactors
+#'
+#' # make predictions
+#' predicted <- predict(model, df)
+#' showDF(predicted)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # set other arguments
+#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
+#' reg = 0.1, nonnegative = TRUE)
+#' statsS <- summary(modelS)
+#' }
+#' @note spark.als since 2.1.0
+setMethod("spark.als", signature(data = "SparkDataFrame"),
+ function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
+ rank = 10, reg = 1.0, maxIter = 10, nonnegative = FALSE,
+ implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
+ checkpointInterval = 10, seed = 0) {
+
+ if (!is.numeric(rank) || rank <= 0) {
+ stop("rank should be a positive number.")
+ }
+ if (!is.numeric(reg) || reg < 0) {
+ stop("reg should be a nonnegative number.")
+ }
+ if (!is.numeric(maxIter) || maxIter <= 0) {
+ stop("maxIter should be a positive number.")
+ }
+
+ jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
+ "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
+ reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+ as.integer(numUserBlocks), as.integer(numItemBlocks),
+ as.integer(checkpointInterval), as.integer(seed))
+ return(new("ALSModel", jobj = jobj))
+ })
+
+# Returns a summary of the ALS model produced by spark.als.
+
+#' @param object a fitted ALS model.
+#' @return \code{summary} returns a list containing the names of the user column,
+#' the item column and the rating column, the estimated user and item factors,
+#' rank, regularization parameter and maximum number of iterations used in training.
+#' @rdname spark.als
+#' @aliases summary,ALSModel-method
+#' @export
+#' @note summary(ALSModel) since 2.1.0
+setMethod("summary", signature(object = "ALSModel"),
+function(object, ...) {
+ jobj <- object@jobj
+ user <- callJMethod(jobj, "userCol")
+ item <- callJMethod(jobj, "itemCol")
+ rating <- callJMethod(jobj, "ratingCol")
+ userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+ itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+ rank <- callJMethod(jobj, "rank")
+ return(list(user = user, item = item, rating = rating, userFactors = userFactors,
+ itemFactors = itemFactors, rank = rank))
+})
+
+
+# Makes predictions from an ALS model produced by spark.als.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.als
+#' @aliases predict,ALSModel-method
+#' @export
+#' @note predict(ALSModel) since 2.1.0
+setMethod("predict", signature(object = "ALSModel"),
+function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+})
+
+
+# Saves the ALS model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite the output path if it
+#'                  already exists. Default is FALSE, which means an exception is thrown
+#'                  if the output path exists.
+#'
+#' @rdname spark.als
+#' @aliases write.ml,ALSModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(ALSModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "ALSModel", path = "character"),
+function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+})
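
For context on what the summary() output above exposes: each ALS prediction is (up to float
rounding) the dot product of one row of userFactors and one row of itemFactors. A minimal
sketch, assuming the model and df objects from the roxygen example and the "id"/"features"
columns of the factor DataFrames:

    stats <- summary(model)
    uf <- collect(stats$userFactors)          # columns: id, features
    itf <- collect(stats$itemFactors)
    u <- unlist(uf$features[uf$id == 0])      # latent vector for user 0, length = rank
    v <- unlist(itf$features[itf$id == 1])    # latent vector for item 1
    sum(u * v)                                # roughly the value predict() returns for (0, 1)
    head(collect(predict(model, df)))         # compare against the "prediction" column
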
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dfb7a185cd..67a3099101 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", {
expect_equal(length(local.posterior), sum(unlist(local.posterior)))
})
+test_that("spark.als", {
+ data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+ list(2, 1, 1.0), list(2, 2, 5.0))
+ df <- createDataFrame(data, c("user", "item", "score"))
+ model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
+ rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+ stats <- summary(model)
+ expect_equal(stats$rank, 10)
+ test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
+ predictions <- collect(predict(model, test))
+
+ expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
+ tolerance = 1e-4)
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats2$rating, "score")
+ userFactors <- collect(stats$userFactors)
+ itemFactors <- collect(stats$itemFactors)
+ userFactors2 <- collect(stats2$userFactors)
+ itemFactors2 <- collect(stats2$itemFactors)
+
+ orderUser <- order(userFactors$id)
+ orderUser2 <- order(userFactors2$id)
+ expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+ expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+
+ orderItem <- order(itemFactors$id)
+ orderItem2 <- order(itemFactors2$id)
+ expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+ expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
new file mode 100644
index 0000000000..ad13cced46
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class ALSWrapper private (
+ val alsModel: ALSModel,
+ val ratingCol: String) extends MLWritable {
+
+ lazy val userCol: String = alsModel.getUserCol
+ lazy val itemCol: String = alsModel.getItemCol
+ lazy val userFactors: DataFrame = alsModel.userFactors
+ lazy val itemFactors: DataFrame = alsModel.itemFactors
+ lazy val rank: Int = alsModel.rank
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ alsModel.transform(dataset)
+ }
+
+ override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this)
+}
+
+private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ ratingCol: String,
+ userCol: String,
+ itemCol: String,
+ rank: Int,
+ regParam: Double,
+ maxIter: Int,
+ implicitPrefs: Boolean,
+ alpha: Double,
+ nonnegative: Boolean,
+ numUserBlocks: Int,
+ numItemBlocks: Int,
+ checkpointInterval: Int,
+ seed: Int): ALSWrapper = {
+
+ val als = new ALS()
+ .setRatingCol(ratingCol)
+ .setUserCol(userCol)
+ .setItemCol(itemCol)
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setMaxIter(maxIter)
+ .setImplicitPrefs(implicitPrefs)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+ .setNumUserBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ .setCheckpointInterval(checkpointInterval)
+ .setSeed(seed.toLong)
+
+ val alsModel: ALSModel = als.fit(data)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+
+ override def read: MLReader[ALSWrapper] = new ALSWrapperReader
+
+ override def load(path: String): ALSWrapper = super.load(path)
+
+ class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("ratingCol" -> instance.ratingCol)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.alsModel.save(modelPath)
+ }
+ }
+
+ class ALSWrapperReader extends MLReader[ALSWrapper] {
+
+ override def load(path: String): ALSWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val ratingCol = (rMetadata \ "ratingCol").extract[String]
+ val alsModel = ALSModel.load(modelPath)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+ }
+
+}
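
A hedged sketch of the on-disk layout ALSWrapperWriter produces (and ALSWrapperReader expects),
assuming write.ml() is pointed at a local filesystem path and that model holds a fitted ALSModel:

    modelPath <- tempfile(pattern = "spark-als")
    write.ml(model, modelPath)
    list.files(modelPath)
    # "model"     "rMetadata"
    # rMetadata/ is a text directory (written with saveAsTextFile) whose single line is roughly
    #   {"class":"org.apache.spark.ml.r.ALSWrapper","ratingCol":"rating"}
    # model/ holds the underlying ALSModel saved by instance.alsModel.save

read.ml() reverses the process: ALSWrapperReader parses ratingCol out of rMetadata and calls
ALSModel.load on the model/ subdirectory.
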
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index e23af51df5..51a65f7fc4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] {
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
GaussianMixtureWrapper.load(path)
+ case "org.apache.spark.ml.r.ALSWrapper" =>
+ ALSWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
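
A hedged round-trip illustration of the dispatch added above, assuming a fitted ALSModel in
model and a writable path: RWrappers reads the "class" field stored in rMetadata and routes to
ALSWrapper.load, so callers never have to name the model type when loading.

    write.ml(model, path)
    restored <- read.ml(path)       # matched by the "org.apache.spark.ml.r.ALSWrapper" case
    class(restored)                 # "ALSModel"
    summary(restored)$rating        # the rating column survives the round trip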