From acac7a508a29d0f75d86ee2e4ca83ebf01a36cf8 Mon Sep 17 00:00:00 2001
From: Junyang Qian
Date: Fri, 19 Aug 2016 14:24:09 -0700
Subject: [SPARK-16443][SPARKR] Alternating Least Squares (ALS) wrapper

## What changes were proposed in this pull request?

Add an Alternating Least Squares (ALS) wrapper in SparkR, and update the unit tests accordingly.

## How was this patch tested?

SparkR unit tests.

![screen shot 2016-07-27 at 3 50 31 pm](https://cloud.githubusercontent.com/assets/15318264/17195347/f7a6352a-5411-11e6-8e21-61a48070192a.png)
![screen shot 2016-07-27 at 3 50 46 pm](https://cloud.githubusercontent.com/assets/15318264/17195348/f7a7d452-5411-11e6-845f-6d292283bc28.png)

Author: Junyang Qian

Closes #14384 from junyangq/SPARK-16443.
---
 .../scala/org/apache/spark/ml/r/ALSWrapper.scala | 119 +++++++++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala  |   2 +
 2 files changed, 121 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
new file mode 100644
index 0000000000..ad13cced46
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.recommendation.{ALS, ALSModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class ALSWrapper private (
+    val alsModel: ALSModel,
+    val ratingCol: String) extends MLWritable {
+
+  lazy val userCol: String = alsModel.getUserCol
+  lazy val itemCol: String = alsModel.getItemCol
+  lazy val userFactors: DataFrame = alsModel.userFactors
+  lazy val itemFactors: DataFrame = alsModel.itemFactors
+  lazy val rank: Int = alsModel.rank
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    alsModel.transform(dataset)
+  }
+
+  override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this)
+}
+
+private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
+
+  def fit(  // scalastyle:ignore
+      data: DataFrame,
+      ratingCol: String,
+      userCol: String,
+      itemCol: String,
+      rank: Int,
+      regParam: Double,
+      maxIter: Int,
+      implicitPrefs: Boolean,
+      alpha: Double,
+      nonnegative: Boolean,
+      numUserBlocks: Int,
+      numItemBlocks: Int,
+      checkpointInterval: Int,
+      seed: Int): ALSWrapper = {
+
+    val als = new ALS()
+      .setRatingCol(ratingCol)
+      .setUserCol(userCol)
+      .setItemCol(itemCol)
+      .setRank(rank)
+      .setRegParam(regParam)
+      .setMaxIter(maxIter)
+      .setImplicitPrefs(implicitPrefs)
+      .setAlpha(alpha)
+      .setNonnegative(nonnegative)
+      .setNumUserBlocks(numUserBlocks)
+      .setNumItemBlocks(numItemBlocks)
+      .setCheckpointInterval(checkpointInterval)
+      .setSeed(seed.toLong)
+
+    val alsModel: ALSModel = als.fit(data)
+
+    new ALSWrapper(alsModel, ratingCol)
+  }
+
+  override def read: MLReader[ALSWrapper] = new ALSWrapperReader
+
+  override def load(path: String): ALSWrapper = super.load(path)
+
+  class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val modelPath = new Path(path, "model").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("ratingCol" -> instance.ratingCol)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.alsModel.save(modelPath)
+    }
+  }
+
+  class ALSWrapperReader extends MLReader[ALSWrapper] {
+
+    override def load(path: String): ALSWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val modelPath = new Path(path, "model").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val ratingCol = (rMetadata \ "ratingCol").extract[String]
+      val alsModel = ALSModel.load(modelPath)
+
+      new ALSWrapper(alsModel, ratingCol)
+    }
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index e23af51df5..51a65f7fc4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] {
         IsotonicRegressionWrapper.load(path)
       case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
         GaussianMixtureWrapper.load(path)
+      case "org.apache.spark.ml.r.ALSWrapper" =>
+        ALSWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }
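
For context, here is a minimal usage sketch of the wrapper added by this patch. It is illustrative only, not part of the change: the `ALSWrapperExample` object and the toy ratings data are made up, the parameter values simply echo the ALS defaults (the seed is arbitrary), and since `ALSWrapper` is `private[r]`, code like this would have to live in the `org.apache.spark.ml.r` package — in practice SparkR invokes the wrapper through the JVM backend rather than from user Scala code.

```scala
package org.apache.spark.ml.r

import org.apache.spark.sql.SparkSession

// Hypothetical example object; not part of this patch.
object ALSWrapperExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ALSWrapperExample").getOrCreate()
    import spark.implicits._

    // Toy ratings with the three columns the wrapper expects.
    val ratings = Seq(
      (0, 0, 4.0), (0, 1, 2.0),
      (1, 1, 3.0), (1, 2, 4.0),
      (2, 0, 1.0), (2, 2, 5.0)
    ).toDF("user", "item", "rating")

    // Argument list mirrors ALSWrapper.fit above; values echo ALS defaults.
    val wrapper = ALSWrapper.fit(
      data = ratings,
      ratingCol = "rating",
      userCol = "user",
      itemCol = "item",
      rank = 10,
      regParam = 0.1,
      maxIter = 10,
      implicitPrefs = false,
      alpha = 1.0,
      nonnegative = false,
      numUserBlocks = 10,
      numItemBlocks = 10,
      checkpointInterval = 10,
      seed = 42)

    // transform() delegates to ALSModel.transform, appending a "prediction" column.
    wrapper.transform(ratings).show()

    spark.stop()
  }
}
```

Persistence round-trips through `wrapper.write` and `ALSWrapper.load`: the rating column name is stored in the `rMetadata` JSON alongside the saved `ALSModel`, and `RWrappers.load` dispatches on the class name recorded there.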