path: root/mllib/src/main
author     Sue Ann Hong <sueann@databricks.com>  2017-03-05 16:49:31 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2017-03-05 16:49:31 -0800
commit     70f9d7f71c63d2b1fdfed75cb7a59285c272a62b (patch)
tree       eddf4cd95e3ac61564b3d9f3a46b83755d965156 /mllib/src/main
parent     369a148e591bb16ec7da54867610b207602cd698 (diff)
[SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe
## What changes were proposed in this pull request?

This is a simple implementation of RecommendForAllUsers & RecommendForAllItems for the DataFrame version of ALS. It uses DataFrame operations (not a wrapper on the RDD implementation). Haven't benchmarked against a wrapper, but the unit test examples do work.

## How was this patch tested?

Unit tests
```
$ build/sbt
> mllib/testOnly *ALSSuite -- -z "recommendFor"
> mllib/testOnly
```

Author: Sue Ann Hong <sueann@databricks.com>
Author: sueann <sueann@databricks.com>

Closes #17090 from sueann/SPARK-19535.
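For orientation, here is a minimal usage sketch of the two methods this patch adds. The SparkSession, input path, and column names are illustrative assumptions, not part of the commit:

```
// Hypothetical end-to-end usage of recommendForAllUsers / recommendForAllItems.
// The session, file path and column names are assumptions for illustration.
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
val ratings = spark.read.parquet("/path/to/ratings")  // assumed columns: userId, movieId, rating

val model = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .fit(ratings)

// Top 10 items for every user, and top 10 users for every item.
val userRecs = model.recommendForAllUsers(10)
val itemRecs = model.recommendForAllItems(10)
userRecs.show(truncate = false)
```

Per the API docs in the diff below, each returned DataFrame has one row per user (or item) with a `recommendations` column holding an array of (id, rating) structs.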
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                | 79
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala  | 60
2 files changed, 130 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 799e881fad..60dd736705 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -284,18 +285,20 @@ class ALSModel private[ml] (
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+ private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+ if (featuresA != null && featuresB != null) {
+ // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+ // potential optimization.
+ blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
- // Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
- val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
- if (userFeatures != null && itemFeatures != null) {
- blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
- } else {
- Float.NaN
- }
- }
val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -327,6 +330,64 @@ class ALSModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+ /**
+ * Returns top `numItems` items recommended for each user, for all users.
+ * @param numItems max number of recommendations for each user
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllUsers(numItems: Int): DataFrame = {
+ recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+ }
+
+ /**
+ * Returns top `numUsers` users recommended for each item, for all items.
+ * @param numUsers max number of recommendations for each item
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllItems(numUsers: Int): DataFrame = {
+ recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+ }
+
+ /**
+ * Makes recommendations for all users (or items).
+ * @param srcFactors src factors for which to generate recommendations
+ * @param dstFactors dst factors used to make recommendations
+ * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+ * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+ * @param num max number of recommendations for each record
+ * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+ * stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+ */
+ private def recommendForAll(
+ srcFactors: DataFrame,
+ dstFactors: DataFrame,
+ srcOutputColumn: String,
+ dstOutputColumn: String,
+ num: Int): DataFrame = {
+ import srcFactors.sparkSession.implicits._
+
+ val ratings = srcFactors.crossJoin(dstFactors)
+ .select(
+ srcFactors("id"),
+ dstFactors("id"),
+ predict(srcFactors("features"), dstFactors("features")))
+ // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+ val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+ .toDF("id", "recommendations")
+
+ val arrayType = ArrayType(
+ new StructType()
+ .add(dstOutputColumn, IntegerType)
+ .add("rating", FloatType)
+ )
+ recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+ }
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
new file mode 100644
index 0000000000..517179c0eb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
+ */
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+ (num: Int, ord: Ordering[(K2, V)])
+ extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+ override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
+ override def reduce(
+ q: BoundedPriorityQueue[(K2, V)],
+ a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+ q += {(a._2, a._3)}
+ }
+
+ override def merge(
+ q1: BoundedPriorityQueue[(K2, V)],
+ q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+ q1 ++= q2
+ }
+
+ override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
+ r.toArray.sorted(ord.reverse)
+ }
+
+ override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
+ Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+ }
+
+ override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
+}
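As a closing illustration, a hypothetical driver for the aggregator itself, mirroring how `recommendForAll` feeds it. It would have to live in the `org.apache.spark.ml.recommendation` package (the class is `private[recommendation]`), and the data is invented:

```
package org.apache.spark.ml.recommendation  // required: TopByKeyAggregator is private[recommendation]

import org.apache.spark.sql.SparkSession

object TopByKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("top-by-key").getOrCreate()
    import spark.implicits._

    // (srcId, dstId, rating) triples; data is made up for illustration.
    val ratings = Seq(
      (0, 10, 3.0f), (0, 11, 5.0f), (0, 12, 1.0f),
      (1, 10, 2.0f), (1, 11, 4.0f)
    ).toDS()

    // Keep the 2 highest-rated dst ids per src id, exactly as recommendForAll does.
    val topK = new TopByKeyAggregator[Int, Int, Float](2, Ordering.by(_._2))
    val recs = ratings.groupByKey(_._1).agg(topK.toColumn).toDF("id", "recommendations")
    recs.show(truncate = false)
    // id 0 -> [(11,5.0), (10,3.0)], id 1 -> [(11,4.0), (10,2.0)]

    spark.stop()
  }
}
```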