author     Junyang Qian <junyangq@databricks.com>    2016-08-19 14:24:09 -0700
committer  Xiangrui Meng <meng@databricks.com>       2016-08-19 14:24:09 -0700
commit     acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8 (patch)
tree       bf01165da59ed904073196844195484318459d81 /mllib
parent     cf0cce90364d17afe780ff9a5426dfcefa298535 (diff)
download   spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.tar.gz
           spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.tar.bz2
           spark-acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8.zip
[SPARK-16443][SPARKR] Alternating Least Squares (ALS) wrapper
## What changes were proposed in this pull request?

Add an Alternating Least Squares (ALS) wrapper in SparkR. Unit tests have been updated.

## How was this patch tested?

SparkR unit tests.

![screen shot 2016-07-27 at 3 50 31 pm](https://cloud.githubusercontent.com/assets/15318264/17195347/f7a6352a-5411-11e6-8e21-61a48070192a.png)
![screen shot 2016-07-27 at 3 50 46 pm](https://cloud.githubusercontent.com/assets/15318264/17195348/f7a7d452-5411-11e6-845f-6d292283bc28.png)

Author: Junyang Qian <junyangq@databricks.com>

Closes #14384 from junyangq/SPARK-16443.
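For context, the new wrapper is a thin layer over the existing Scala `ALS` estimator. Since `ALSWrapper` itself is `private[r]` and reached only through SparkR's JVM bridge, here is a minimal sketch of the equivalent public `ALS` call it delegates to; the column names and parameter values are illustrative, not taken from this patch:

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

object AlsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("als-sketch").getOrCreate()

    // Toy explicit-feedback ratings: (user, item, rating).
    val ratings = spark.createDataFrame(Seq(
      (0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 0, 1.0)
    )).toDF("user", "item", "rating")

    // The same setter chain ALSWrapper.fit builds, with illustrative values.
    val als = new ALS()
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")
      .setRank(10)
      .setRegParam(0.1)
      .setMaxIter(10)
      .setSeed(42L)

    val model = als.fit(ratings)
    model.transform(ratings).show() // appends a "prediction" column
  }
}
```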
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala | 119
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala  |   2
2 files changed, 121 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
new file mode 100644
index 0000000000..ad13cced46
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class ALSWrapper private (
+ val alsModel: ALSModel,
+ val ratingCol: String) extends MLWritable {
+
+ lazy val userCol: String = alsModel.getUserCol
+ lazy val itemCol: String = alsModel.getItemCol
+ lazy val userFactors: DataFrame = alsModel.userFactors
+ lazy val itemFactors: DataFrame = alsModel.itemFactors
+ lazy val rank: Int = alsModel.rank
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ alsModel.transform(dataset)
+ }
+
+ override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this)
+}
+
+private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
+
+ def fit( // scalastyle:ignore
+ data: DataFrame,
+ ratingCol: String,
+ userCol: String,
+ itemCol: String,
+ rank: Int,
+ regParam: Double,
+ maxIter: Int,
+ implicitPrefs: Boolean,
+ alpha: Double,
+ nonnegative: Boolean,
+ numUserBlocks: Int,
+ numItemBlocks: Int,
+ checkpointInterval: Int,
+ seed: Int): ALSWrapper = {
+
+ val als = new ALS()
+ .setRatingCol(ratingCol)
+ .setUserCol(userCol)
+ .setItemCol(itemCol)
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setMaxIter(maxIter)
+ .setImplicitPrefs(implicitPrefs)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+ .setNumUserBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ .setCheckpointInterval(checkpointInterval)
+ .setSeed(seed.toLong)
+
+ val alsModel: ALSModel = als.fit(data)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+
+ override def read: MLReader[ALSWrapper] = new ALSWrapperReader
+
+ override def load(path: String): ALSWrapper = super.load(path)
+
+ class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("ratingCol" -> instance.ratingCol)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.alsModel.save(modelPath)
+ }
+ }
+
+ class ALSWrapperReader extends MLReader[ALSWrapper] {
+
+ override def load(path: String): ALSWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val modelPath = new Path(path, "model").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val ratingCol = (rMetadata \ "ratingCol").extract[String]
+ val alsModel = ALSModel.load(modelPath)
+
+ new ALSWrapper(alsModel, ratingCol)
+ }
+ }
+
+}
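Aside: the writer/reader pair above persists the wrapper as two artifacts under one directory — a one-line JSON `rMetadata` file carrying `class` and `ratingCol`, plus the underlying `ALSModel`. A hedged sketch of the save/load round trip (the path is illustrative, and this only compiles inside the `org.apache.spark.ml.r` package because the wrapper is `private[r]`):

```scala
package org.apache.spark.ml.r

import org.apache.spark.sql.DataFrame

object ALSWrapperRoundTripSketch {
  // `wrapper` is a fitted ALSWrapper; `data` and `path` are illustrative.
  def roundTrip(wrapper: ALSWrapper, data: DataFrame, path: String): DataFrame = {
    wrapper.write.save(path)             // writes <path>/rMetadata and <path>/model
    val restored = ALSWrapper.load(path) // ratingCol is re-read from the JSON metadata
    restored.transform(data)             // predictions match the original wrapper
  }
}
```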
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index e23af51df5..51a65f7fc4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] {
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
GaussianMixtureWrapper.load(path)
+ case "org.apache.spark.ml.r.ALSWrapper" =>
+ ALSWrapper.load(path)
case _ =>
throw new SparkException(s"SparkR read.ml does not support load $className")
}
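The new case clause wires ALS persistence into SparkR's generic loader: `RWrappers.load` reads the `class` field from `rMetadata` and dispatches to the matching wrapper's `load`. A hedged sketch of that dispatch (illustrative path; again `private[r]`, so it only compiles inside the `org.apache.spark.ml.r` package):

```scala
package org.apache.spark.ml.r

object ReadMlDispatchSketch {
  // With the new case clause, a model saved by an ALSWrapper resolves here
  // to ALSWrapper.load(path); unknown class names still raise SparkException.
  def load(path: String): ALSWrapper =
    RWrappers.load(path).asInstanceOf[ALSWrapper]
}
```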