-rw-r--r--  R/pkg/DESCRIPTION                                                   1
-rw-r--r--  R/pkg/NAMESPACE                                                     5
-rw-r--r--  R/pkg/R/generics.R                                                 12
-rw-r--r--  R/pkg/R/mllib_fpm.R                                               158
-rw-r--r--  R/pkg/R/mllib_utils.R                                               2
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib_fpm.R                         83
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala  86
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala         2
8 files changed, 348 insertions(+), 1 deletion(-)
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 00dde64324..f475ee8770 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -44,6 +44,7 @@ Collate:
'jvm.R'
'mllib_classification.R'
'mllib_clustering.R'
+ 'mllib_fpm.R'
'mllib_recommendation.R'
'mllib_regression.R'
'mllib_stat.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c02046c94b..9b7e95ce30 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -66,7 +66,10 @@ exportMethods("glm",
"spark.randomForest",
"spark.gbt",
"spark.bisectingKmeans",
- "spark.svmLinear")
+ "spark.svmLinear",
+ "spark.fpGrowth",
+ "spark.freqItemsets",
+ "spark.associationRules")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 80283e48ce..945676c7f1 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1445,6 +1445,18 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark
#' @export
setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") })
+
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") })
+
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") })
+
#' @param object a fitted ML model object.
#' @param path the directory where the model is saved.
#' @param ... additional argument(s) passed to the method.
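The three generics declared above are bound to concrete methods on the FPGrowthModel class in mllib_fpm.R below. For readers less familiar with S4 dispatch, here is a self-contained toy sketch of the same declare-then-bind pattern (hypothetical names, base R only, no SparkR required):

    library(methods)  # S4 machinery

    # Declare the class and the verb, then bind the verb to the class.
    setClass("ToyModel", slots = list(jobj = "character"))
    setGeneric("spark.toyItemsets", function(object) { standardGeneric("spark.toyItemsets") })
    setMethod("spark.toyItemsets", signature(object = "ToyModel"),
              function(object) {
                # SparkR method bodies call into the JVM via callJMethod;
                # the toy just derives a string from its slot.
                paste("itemsets from", object@jobj)
              })

    m <- new("ToyModel", jobj = "backing-jvm-handle")
    spark.toyItemsets(m)  # "itemsets from backing-jvm-handle"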
diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R
new file mode 100644
index 0000000000..96251b2c7c
--- /dev/null
+++ b/R/pkg/R/mllib_fpm.R
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_fpm.R: Provides methods for integrating MLlib frequent pattern mining algorithms
+
+#' S4 class that represents an FPGrowthModel
+#'
+#' @param jobj a Java object reference to the backing Scala FPGrowthModel
+#' @export
+#' @note FPGrowthModel since 2.2.0
+setClass("FPGrowthModel", slots = list(jobj = "jobj"))
+
+#' FP-growth
+#'
+#' A parallel FP-growth algorithm to mine frequent itemsets.
+#' For more details, see
+#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{
+#' FP-growth}.
+#'
+#' @param data A SparkDataFrame for training.
+#' @param minSupport Minimal support level (default: 0.3).
+#' @param minConfidence Minimal confidence level (default: 0.8).
+#' @param itemsCol Features column name (default: \code{"items"}).
+#' @param numPartitions Number of partitions used for fitting. If NULL (the default),
+#'        the number of partitions of the input dataset is used.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model.
+#' @rdname spark.fpGrowth
+#' @name spark.fpGrowth
+#' @aliases spark.fpGrowth,SparkDataFrame-method
+#' @export
+#' @examples
+#' \dontrun{
+#' raw_data <- read.df(
+#' "data/mllib/sample_fpgrowth.txt",
+#' source = "csv",
+#' schema = structType(structField("raw_items", "string")))
+#'
+#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items")
+#' model <- spark.fpGrowth(data)
+#'
+#' # Show frequent itemsets
+#' frequent_itemsets <- spark.freqItemsets(model)
+#' showDF(frequent_itemsets)
+#'
+#' # Show association rules
+#' association_rules <- spark.associationRules(model)
+#' showDF(association_rules)
+#'
+#' # Predict on new data
+#' new_itemsets <- data.frame(items = c("t", "t,s"))
+#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items")
+#' predict(model, new_data)
+#'
+#' # Save and load model
+#' path <- "/path/to/model"
+#' write.ml(model, path)
+#' read.ml(path)
+#'
+#' # Optional arguments
+#' baskets_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as baskets")
+#' another_model <- spark.fpGrowth(baskets_data, minSupport = 0.1, minConfidence = 0.5,
+#'                                 itemsCol = "baskets", numPartitions = 10)
+#' }
+#' @note spark.fpGrowth since 2.2.0
+setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"),
+ function(data, minSupport = 0.3, minConfidence = 0.8,
+ itemsCol = "items", numPartitions = NULL) {
+ if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) {
+ stop("minSupport should be a number [0, 1].")
+ }
+ if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) {
+ stop("minConfidence should be a number [0, 1].")
+ }
+ if (!is.null(numPartitions)) {
+ numPartitions <- as.integer(numPartitions)
+ stopifnot(numPartitions > 0)
+ }
+
+ jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit",
+ data@sdf, as.numeric(minSupport), as.numeric(minConfidence),
+ itemsCol, numPartitions)
+ new("FPGrowthModel", jobj = jobj)
+ })
+
+# Get frequent itemsets.
+
+#' @param object a fitted FPGrowth model.
+#' @return A \code{SparkDataFrame} with frequent itemsets.
+#' The \code{SparkDataFrame} contains two columns:
+#' \code{items} (an array of the same type as the input column)
+#' and \code{freq} (frequency of the itemset).
+#' @rdname spark.fpGrowth
+#' @aliases spark.freqItemsets,FPGrowthModel-method
+#' @export
+#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0
+setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
+ function(object) {
+ dataFrame(callJMethod(object@jobj, "freqItemsets"))
+ })
+
+# Get association rules.
+
+#' @return A \code{SparkDataFrame} with association rules.
+#' The \code{SparkDataFrame} contains three columns:
+#' \code{antecedent} (an array of the same type as the input column),
+#' \code{consequent} (an array of the same type as the input column),
+#' and \code{confidence} (confidence of the rule).
+#' @rdname spark.fpGrowth
+#' @aliases spark.associationRules,FPGrowthModel-method
+#' @export
+#' @note spark.associationRules(FPGrowthModel) since 2.2.0
+setMethod("spark.associationRules", signature(object = "FPGrowthModel"),
+ function(object) {
+ dataFrame(callJMethod(object@jobj, "associationRules"))
+ })
+
+# Makes predictions based on the generated association rules.
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.fpGrowth
+#' @aliases predict,FPGrowthModel-method
+#' @export
+#' @note predict(FPGrowthModel) since 2.2.0
+setMethod("predict", signature(object = "FPGrowthModel"),
+ function(object, newData) {
+ predict_internal(object, newData)
+ })
+
+# Saves the FPGrowth model to the output path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#' already exists. Default is FALSE, which means an exception is thrown
+#' if the output path exists.
+#' @rdname spark.fpGrowth
+#' @aliases write.ml,FPGrowthModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(FPGrowthModel, character) since 2.2.0
+setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ write_internal(object, path, overwrite)
+ })
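Taken together, the methods above give the following end-to-end flow. This is a minimal sketch assembled from the roxygen examples, assuming a running Spark build and its bundled data/mllib/sample_fpgrowth.txt (one space-separated transaction per line, e.g. "r z h k p"):

    library(SparkR)
    sparkR.session()

    # Read one transaction per line, then split it into an array column "items".
    raw <- read.df("data/mllib/sample_fpgrowth.txt", source = "csv",
                   schema = structType(structField("raw_items", "string")))
    transactions <- selectExpr(raw, "split(raw_items, ' ') as items")

    # Defaults: minSupport = 0.3, minConfidence = 0.8, itemsCol = "items".
    model <- spark.fpGrowth(transactions)

    showDF(spark.freqItemsets(model))      # columns: items, freq
    showDF(spark.associationRules(model))  # columns: antecedent, consequent, confidence

    # predict() adds a "prediction" column holding rule consequents that are
    # not already contained in the input itemset.
    showDF(predict(model, transactions))

    sparkR.session.stop()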
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 04a0a6f944..5dfef86250 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -118,6 +118,8 @@ read.ml <- function(path) {
new("BisectingKMeansModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) {
new("LinearSVCModel", jobj = jobj)
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
+ new("FPGrowthModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}
diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R
new file mode 100644
index 0000000000..c38f113389
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib frequent pattern mining")
+
+# Tests for MLlib frequent pattern mining algorithms in SparkR
+sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+
+test_that("spark.fpGrowth", {
+ data <- selectExpr(createDataFrame(data.frame(items = c(
+ "1,2",
+ "1,2",
+ "1,2,3",
+ "1,3"
+ ))), "split(items, ',') as items")
+
+ model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1)
+
+ itemsets <- collect(spark.freqItemsets(model))
+
+ expected_itemsets <- data.frame(
+ items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))),
+ freq = c(2, 2, 3, 3, 4)
+ )
+
+ expect_equivalent(expected_itemsets, itemsets)
+
+ expected_association_rules <- data.frame(
+ antecedent = I(list(list("2"), list("3"))),
+ consequent = I(list(list("1"), list("1"))),
+ confidence = c(1, 1)
+ )
+
+ expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))
+
+ new_data <- selectExpr(createDataFrame(data.frame(items = c(
+ "1,2",
+ "1,3",
+ "2,3"
+ ))), "split(items, ',') as items")
+
+ expected_predictions <- data.frame(
+ items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))),
+ prediction = I(list(list(), list(), list("1")))
+ )
+
+ expect_equivalent(expected_predictions, collect(predict(model, new_data)))
+
+ modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp")
+ write.ml(model, modelPath, overwrite = TRUE)
+ loaded_model <- read.ml(modelPath)
+
+ expect_equivalent(
+ itemsets,
+ collect(spark.freqItemsets(loaded_model)))
+
+ unlink(modelPath)
+
+ model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8)
+ expect_equal(
+ count(spark.freqItemsets(model_without_numpartitions)),
+ count(spark.freqItemsets(model))
+ )
+
+})
+
+sparkR.session.stop()
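The expected values hard-coded in the test can be checked by hand: with minSupport = 0.3 over four transactions, an itemset must occur in at least ceiling(0.3 * 4) = 2 of them, and a rule's confidence is support(antecedent plus consequent) / support(antecedent). A Spark-free base-R recount (a reviewer's sanity-check sketch, not part of the test suite):

    transactions <- list(c("1", "2"), c("1", "2"), c("1", "2", "3"), c("1", "3"))

    # Number of transactions containing every item of a candidate itemset.
    support_count <- function(itemset) {
      sum(vapply(transactions, function(t) all(itemset %in% t), logical(1)))
    }

    support_count("1")          # 4
    support_count("2")          # 3
    support_count(c("2", "1"))  # 3
    support_count("3")          # 2
    support_count(c("3", "1"))  # 2

    # Confidence of the two expected rules, {2} => {1} and {3} => {1}:
    support_count(c("2", "1")) / support_count("2")  # 1
    support_count(c("3", "1")) / support_count("3")  # 1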
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
new file mode 100644
index 0000000000..b8151d8d90
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable {
+ def freqItemsets: DataFrame = fpGrowthModel.freqItemsets
+ def associationRules: DataFrame = fpGrowthModel.associationRules
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ fpGrowthModel.transform(dataset)
+ }
+
+ override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this)
+}
+
+private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] {
+
+ def fit(
+ data: DataFrame,
+ minSupport: Double,
+ minConfidence: Double,
+ itemsCol: String,
+ numPartitions: Integer): FPGrowthWrapper = {
+ val fpGrowth = new FPGrowth()
+ .setMinSupport(minSupport)
+ .setMinConfidence(minConfidence)
+ .setItemsCol(itemsCol)
+
+ if (numPartitions != null && numPartitions > 0) {
+ fpGrowth.setNumPartitions(numPartitions)
+ }
+
+ val fpGrowthModel = fpGrowth.fit(data)
+
+ new FPGrowthWrapper(fpGrowthModel)
+ }
+
+ override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader
+
+ class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] {
+ override def load(path: String): FPGrowthWrapper = {
+ val modelPath = new Path(path, "model").toString
+ val fPGrowthModel = FPGrowthModel.load(modelPath)
+
+ new FPGrowthWrapper(fPGrowthModel)
+ }
+ }
+
+ class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter {
+ override protected def saveImpl(path: String): Unit = {
+ val modelPath = new Path(path, "model").toString
+ val rMetadataPath = new Path(path, "rMetadata").toString
+
+ val rMetadataJson: String = compact(render(
+ "class" -> instance.getClass.getName
+ ))
+
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.fpGrowthModel.save(modelPath)
+ }
+ }
+}
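FPGrowthWrapperWriter persists two artifacts under the target path, and the class name stored in rMetadata is what lets the generic R loader (the read.ml branch in mllib_utils.R above, via RWrappers in the next hunk) dispatch back to FPGrowthWrapper.load. A sketch of the round trip from R, assuming a fitted model as in the earlier example:

    path <- tempfile(pattern = "spark-fpm")
    write.ml(model, path)

    # Expected layout (rMetadata is written with saveAsTextFile, so it is a
    # directory whose single part file holds one JSON line):
    #   <path>/rMetadata/part-00000 -> {"class":"org.apache.spark.ml.r.FPGrowthWrapper"}
    #   <path>/model/               -> the underlying ml.fpm.FPGrowthModel
    list.files(path, recursive = TRUE)

    loaded <- read.ml(path)  # matches the stored class name, returns an FPGrowthModel
    showDF(spark.freqItemsets(loaded))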
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 358e522dfe..b30ce12bc6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] {
BisectingKMeansWrapper.load(path)
case "org.apache.spark.ml.r.LinearSVCWrapper" =>
LinearSVCWrapper.load(path)
+ case "org.apache.spark.ml.r.FPGrowthWrapper" =>
+ FPGrowthWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}