author     Junyang Qian <junyangq@databricks.com>    2016-08-19 14:24:09 -0700
committer  Xiangrui Meng <meng@databricks.com>       2016-08-19 14:24:09 -0700
commit     acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8 (patch)
tree       bf01165da59ed904073196844195484318459d81 /R
parent     cf0cce90364d17afe780ff9a5426dfcefa298535 (diff)
[SPARK-16443][SPARKR] Alternating Least Squares (ALS) wrapper
## What changes were proposed in this pull request?

Add an Alternating Least Squares (ALS) wrapper in SparkR. Unit tests have been updated.

## How was this patch tested?

SparkR unit tests.

![screen shot 2016-07-27 at 3 50 31 pm](https://cloud.githubusercontent.com/assets/15318264/17195347/f7a6352a-5411-11e6-8e21-61a48070192a.png)
![screen shot 2016-07-27 at 3 50 46 pm](https://cloud.githubusercontent.com/assets/15318264/17195348/f7a7d452-5411-11e6-845f-6d292283bc28.png)

Author: Junyang Qian <junyangq@databricks.com>

Closes #14384 from junyangq/SPARK-16443.
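For readers unfamiliar with the new API, here is a minimal usage sketch (not part of the patch itself), adapted from the roxygen examples added in R/pkg/R/mllib.R below; the column names and rating values are illustrative:

```r
library(SparkR)
sparkR.session()

# Toy explicit-feedback ratings: (user, item, rating) triples.
ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0),
                list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0))
df <- createDataFrame(ratings, c("user", "item", "rating"))

# Fit the ALS model exposed by this patch.
model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, maxIter = 5)

# Inspect latent factors and the rank used in training.
stats <- summary(model)
head(collect(stats$userFactors))

# Predict ratings and round-trip the model to disk.
predicted <- predict(model, df)
showDF(predicted)
path <- tempfile(pattern = "spark-als")
write.ml(model, path)
savedModel <- read.ml(path)
```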
Diffstat (limited to 'R')
-rw-r--r--  R/pkg/NAMESPACE                         |   3
-rw-r--r--  R/pkg/R/generics.R                      |   4
-rw-r--r--  R/pkg/R/mllib.R                         | 159
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R  |  40
4 files changed, 201 insertions(+), 5 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 4404cffc29..e1b87b28d3 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -29,7 +29,8 @@ exportMethods("glm",
"spark.posterior",
"spark.perplexity",
"spark.isoreg",
- "spark.gaussianMixture")
+ "spark.gaussianMixture",
+ "spark.als")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index fe04bcfc7d..693aa31d3e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1332,3 +1332,7 @@ setGeneric("spark.gaussianMixture",
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
+
+#' @rdname spark.als
+#' @export
+setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index b9527410a9..36f38fc73a 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -74,6 +74,13 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
#' @note GaussianMixtureModel since 2.1.0
setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+#' S4 class that represents an ALSModel
+#'
+#' @param jobj a Java object reference to the backing Scala ALSWrapper
+#' @export
+#' @note ALSModel since 2.1.0
+setClass("ALSModel", representation(jobj = "jobj"))
+
#' Saves the MLlib model to the input path
#'
#' Saves the MLlib model to the input path. For more information, see the specific
@@ -82,8 +89,8 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
+#' @seealso \link{spark.survreg}, \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -95,10 +102,11 @@ NULL
#' @name predict
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{spark.isoreg}
NULL
+
#' Generalized Linear Models
#'
#' Fits generalized linear model against a Spark DataFrame.
@@ -801,6 +809,8 @@ read.ml <- function(path) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
return(new("GaussianMixtureModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
+ return(new("ALSModel", jobj = jobj))
} else {
stop(paste("Unsupported model: ", jobj))
}
@@ -1053,4 +1063,145 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
- })
\ No newline at end of file
+ })
+
+#' Alternating Least Squares (ALS) for Collaborative Filtering
+#'
+#' \code{spark.als} learns latent factors in collaborative filtering via alternating least
+#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict}
+#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' For more details, see
+#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib:
+#' Collaborative Filtering}.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param ratingCol column name for ratings.
+#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers.
+#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers.
+#' @param rank rank of the matrix factorization (> 0).
+#' @param reg regularization parameter (>= 0).
+#' @param maxIter maximum number of iterations (>= 0).
+#' @param nonnegative logical value indicating whether to apply nonnegativity constraints.
+#' @param implicitPrefs logical value indicating whether to use implicit preference.
+#' @param alpha alpha parameter in the implicit preference formulation (>= 0).
+#' @param seed integer seed for random number generation.
+#' @param numUserBlocks number of user blocks used to parallelize computation (> 0).
+#' @param numItemBlocks number of item blocks used to parallelize computation (> 0).
+#' @param checkpointInterval number of iterations between checkpoints (>= 1), or -1 to disable checkpointing.
+#'
+#' @return \code{spark.als} returns a fitted ALS model.
+#' @rdname spark.als
+#' @aliases spark.als,SparkDataFrame-method
+#' @name spark.als
+#' @export
+#' @examples
+#' \dontrun{
+#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+#' list(2, 1, 1.0), list(2, 2, 5.0))
+#' df <- createDataFrame(ratings, c("user", "item", "rating"))
+#' model <- spark.als(df, "rating", "user", "item")
+#'
+#' # extract latent factors
+#' stats <- summary(model)
+#' userFactors <- stats$userFactors
+#' itemFactors <- stats$itemFactors
+#'
+#' # make predictions
+#' predicted <- predict(model, df)
+#' showDF(predicted)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # set other arguments
+#' modelS <- spark.als(df, "rating", "user", "item", rank = 20,
+#' reg = 0.1, nonnegative = TRUE)
+#' statsS <- summary(modelS)
+#' }
+#' @note spark.als since 2.1.0
+setMethod("spark.als", signature(data = "SparkDataFrame"),
+ function(data, ratingCol = "rating", userCol = "user", itemCol = "item",
+ rank = 10, reg = 1.0, maxIter = 10, nonnegative = FALSE,
+ implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10,
+ checkpointInterval = 10, seed = 0) {
+
+ if (!is.numeric(rank) || rank <= 0) {
+ stop("rank should be a positive number.")
+ }
+ if (!is.numeric(reg) || reg < 0) {
+ stop("reg should be a nonnegative number.")
+ }
+ if (!is.numeric(maxIter) || maxIter <= 0) {
+ stop("maxIter should be a positive number.")
+ }
+
+ jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper",
+ "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank),
+ reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
+ as.integer(numUserBlocks), as.integer(numItemBlocks),
+ as.integer(checkpointInterval), as.integer(seed))
+ return(new("ALSModel", jobj = jobj))
+ })
+
+# Returns a summary of the ALS model produced by spark.als.
+
+#' @param object a fitted ALS model.
+#' @return \code{summary} returns a list containing the names of the user column,
+#' the item column and the rating column, the estimated user and item factors,
+#' and the rank used in training.
+#' @rdname spark.als
+#' @aliases summary,ALSModel-method
+#' @export
+#' @note summary(ALSModel) since 2.1.0
+setMethod("summary", signature(object = "ALSModel"),
+function(object, ...) {
+ jobj <- object@jobj
+ user <- callJMethod(jobj, "userCol")
+ item <- callJMethod(jobj, "itemCol")
+ rating <- callJMethod(jobj, "ratingCol")
+ userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+ itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+ rank <- callJMethod(jobj, "rank")
+ return(list(user = user, item = item, rating = rating, userFactors = userFactors,
+ itemFactors = itemFactors, rank = rank))
+})
+
+
+# Makes predictions from an ALS model produced by spark.als.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.als
+#' @aliases predict,ALSModel-method
+#' @export
+#' @note predict(ALSModel) since 2.1.0
+setMethod("predict", signature(object = "ALSModel"),
+function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+})
+
+
+# Saves the ALS model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#' already exists. Defaults to FALSE, which means an exception is thrown
+#' if the output path already exists.
+#'
+#' @rdname spark.als
+#' @aliases write.ml,ALSModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(ALSModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "ALSModel", path = "character"),
+function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+})
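Taken together, the additions above cover fitting, summary, prediction, and save/load for ALS. As a small illustrative sketch (not part of the patch), the argument validation in the new spark.als method surfaces to callers as shown below; it assumes df is the ratings SparkDataFrame from the commit-message example:

```r
# Illustrative only: rank must be > 0, reg >= 0, and maxIter > 0;
# otherwise the R wrapper stops before calling into the JVM backend.
msg <- tryCatch({
  spark.als(df, "rating", "user", "item", rank = -1)
  "no error"
}, error = function(e) conditionMessage(e))
print(msg)  # expected to mention "rank should be a positive number."
```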
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index dfb7a185cd..67a3099101 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", {
expect_equal(length(local.posterior), sum(unlist(local.posterior)))
})
+test_that("spark.als", {
+ data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0),
+ list(2, 1, 1.0), list(2, 2, 5.0))
+ df <- createDataFrame(data, c("user", "item", "score"))
+ model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item",
+ rank = 10, maxIter = 5, seed = 0, reg = 0.1)
+ stats <- summary(model)
+ expect_equal(stats$rank, 10)
+ test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item"))
+ predictions <- collect(predict(model, test))
+
+ expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409),
+ tolerance = 1e-4)
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+ expect_equal(stats2$rating, "score")
+ userFactors <- collect(stats$userFactors)
+ itemFactors <- collect(stats$itemFactors)
+ userFactors2 <- collect(stats2$userFactors)
+ itemFactors2 <- collect(stats2$itemFactors)
+
+ orderUser <- order(userFactors$id)
+ orderUser2 <- order(userFactors2$id)
+ expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2])
+ expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2])
+
+ orderItem <- order(itemFactors$id)
+ orderItem2 <- order(itemFactors2$id)
+ expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2])
+ expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2])
+
+ unlink(modelPath)
+})
+
sparkR.session.stop()