author     Sue Ann Hong <sueann@databricks.com>        2017-03-05 16:49:31 -0800
committer  Joseph K. Bradley <joseph@databricks.com>   2017-03-05 16:49:31 -0800
commit     70f9d7f71c63d2b1fdfed75cb7a59285c272a62b (patch)
tree       eddf4cd95e3ac61564b3d9f3a46b83755d965156 /mllib
parent     369a148e591bb16ec7da54867610b207602cd698 (diff)
[SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe
## What changes were proposed in this pull request?

This is a simple implementation of RecommendForAllUsers & RecommendForAllItems for the DataFrame version of ALS. It uses DataFrame operations directly (not a wrapper on the RDD implementation). It has not been benchmarked against a wrapper, but the unit test examples work.

## How was this patch tested?

Unit tests:
```
$ build/sbt
> mllib/testOnly *ALSSuite -- -z "recommendFor"
> mllib/testOnly
```

Author: Your Name <you@example.com>
Author: sueann <sueann@databricks.com>

Closes #17090 from sueann/SPARK-19535.
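For reference, a minimal usage sketch of the new API (illustrative only; `ratings` stands for any (user, item, rating) DataFrame, and the column names match the setup used in the new tests):

```scala
import org.apache.spark.ml.recommendation.ALS

// Fit ALS on a DataFrame with (user, item, rating) columns.
val als = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
val model = als.fit(ratings)

// Top 3 items for every user: (user: Int, recommendations: Array[(item, rating)]).
val topItemsPerUser = model.recommendForAllUsers(3)

// Top 3 users for every item: (item: Int, recommendations: Array[(user, rating)]).
val topUsersPerItem = model.recommendForAllItems(3)
```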
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                     79
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala      60
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala                94
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala 73
4 files changed, 297 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 799e881fad..60dd736705 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -284,18 +285,20 @@ class ALSModel private[ml] (
@Since("2.2.0")
def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+ private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+ if (featuresA != null && featuresB != null) {
+ // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+ // potential optimization.
+ blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
- // Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
- val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
- if (userFeatures != null && itemFeatures != null) {
- blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
- } else {
- Float.NaN
- }
- }
val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -327,6 +330,64 @@ class ALSModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+ /**
+ * Returns top `numItems` items recommended for each user, for all users.
+ * @param numItems max number of recommendations for each user
+ * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+ * stored as an array of (itemCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllUsers(numItems: Int): DataFrame = {
+ recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+ }
+
+ /**
+ * Returns top `numUsers` users recommended for each item, for all items.
+ * @param numUsers max number of recommendations for each item
+ * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+ * stored as an array of (userCol: Int, rating: Float) Rows.
+ */
+ @Since("2.2.0")
+ def recommendForAllItems(numUsers: Int): DataFrame = {
+ recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+ }
+
+ /**
+ * Makes recommendations for all users (or items).
+ * @param srcFactors src factors for which to generate recommendations
+ * @param dstFactors dst factors used to make recommendations
+ * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+ * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+ * @param num max number of recommendations for each record
+ * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+ * stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+ */
+ private def recommendForAll(
+ srcFactors: DataFrame,
+ dstFactors: DataFrame,
+ srcOutputColumn: String,
+ dstOutputColumn: String,
+ num: Int): DataFrame = {
+ import srcFactors.sparkSession.implicits._
+
+ val ratings = srcFactors.crossJoin(dstFactors)
+ .select(
+ srcFactors("id"),
+ dstFactors("id"),
+ predict(srcFactors("features"), dstFactors("features")))
+ // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+ val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+ .toDF("id", "recommendations")
+
+ val arrayType = ArrayType(
+ new StructType()
+ .add(dstOutputColumn, IntegerType)
+ .add("rating", FloatType)
+ )
+ recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+ }
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
new file mode 100644
index 0000000000..517179c0eb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
+ */
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+ (num: Int, ord: Ordering[(K2, V)])
+ extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+ override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
+ override def reduce(
+ q: BoundedPriorityQueue[(K2, V)],
+ a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+ q += {(a._2, a._3)}
+ }
+
+ override def merge(
+ q1: BoundedPriorityQueue[(K2, V)],
+ q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+ q1 ++= q2
+ }
+
+ override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
+ r.toArray.sorted(ord.reverse)
+ }
+
+ override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
+ Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+ }
+
+ override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
+}
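As a quick illustration of how the aggregator is driven (this mirrors the new TopByKeyAggregatorSuite below; `spark` is assumed to be an active SparkSession):

```scala
import spark.implicits._

// (key, secondary key, score) triples; keep the 2 highest-scoring entries per key.
val topByKey = new TopByKeyAggregator[Int, Int, Float](2, Ordering.by(_._2))
val top2 = Seq((0, 3, 54f), (0, 4, 44f), (0, 5, 42f), (1, 3, 39f))
  .toDS()
  .groupByKey(_._1)
  .agg(topByKey.toColumn)   // Dataset[(Int, Array[(Int, Float)])]
```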
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c8228dd004..e494ea89e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,6 +22,7 @@ import java.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.WrappedArray
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -660,6 +661,99 @@ class ALSSuite
model.setColdStartStrategy(s).transform(data)
}
}
+
+ private def getALSModel = {
+ val spark = this.spark
+ import spark.implicits._
+
+ val userFactors = Seq(
+ (0, Array(6.0f, 4.0f)),
+ (1, Array(3.0f, 4.0f)),
+ (2, Array(3.0f, 6.0f))
+ ).toDF("id", "features")
+ val itemFactors = Seq(
+ (3, Array(5.0f, 6.0f)),
+ (4, Array(6.0f, 2.0f)),
+ (5, Array(3.0f, 6.0f)),
+ (6, Array(4.0f, 1.0f))
+ ).toDF("id", "features")
+ val als = new ALS().setRank(2)
+ new ALSModel(als.uid, als.getRank, userFactors, itemFactors)
+ .setUserCol("user")
+ .setItemCol("item")
+ }
+
+ test("recommendForAllUsers with k < num_items") {
+ val topItems = getALSModel.recommendForAllUsers(2)
+ assert(topItems.count() == 3)
+ assert(topItems.columns.contains("user"))
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f)),
+ 1 -> Array((3, 39f), (5, 33f)),
+ 2 -> Array((3, 51f), (5, 45f))
+ )
+ checkRecommendations(topItems, expected, "item")
+ }
+
+ test("recommendForAllUsers with k = num_items") {
+ val topItems = getALSModel.recommendForAllUsers(4)
+ assert(topItems.count() == 3)
+ assert(topItems.columns.contains("user"))
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+ 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
+ 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
+ )
+ checkRecommendations(topItems, expected, "item")
+ }
+
+ test("recommendForAllItems with k < num_users") {
+ val topUsers = getALSModel.recommendForAllItems(2)
+ assert(topUsers.count() == 4)
+ assert(topUsers.columns.contains("item"))
+
+ val expected = Map(
+ 3 -> Array((0, 54f), (2, 51f)),
+ 4 -> Array((0, 44f), (2, 30f)),
+ 5 -> Array((2, 45f), (0, 42f)),
+ 6 -> Array((0, 28f), (2, 18f))
+ )
+ checkRecommendations(topUsers, expected, "user")
+ }
+
+ test("recommendForAllItems with k = num_users") {
+ val topUsers = getALSModel.recommendForAllItems(3)
+ assert(topUsers.count() == 4)
+ assert(topUsers.columns.contains("item"))
+
+ val expected = Map(
+ 3 -> Array((0, 54f), (2, 51f), (1, 39f)),
+ 4 -> Array((0, 44f), (2, 30f), (1, 26f)),
+ 5 -> Array((2, 45f), (0, 42f), (1, 33f)),
+ 6 -> Array((0, 28f), (2, 18f), (1, 16f))
+ )
+ checkRecommendations(topUsers, expected, "user")
+ }
+
+ private def checkRecommendations(
+ topK: DataFrame,
+ expected: Map[Int, Array[(Int, Float)]],
+ dstColName: String): Unit = {
+ val spark = this.spark
+ import spark.implicits._
+
+ assert(topK.columns.contains("recommendations"))
+ topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) =>
+ assert(recs === expected(id))
+ }
+ topK.collect().foreach { row =>
+ val recs = row.getAs[WrappedArray[Row]]("recommendations")
+ assert(recs(0).fieldIndex(dstColName) == 0)
+ assert(recs(0).fieldIndex("rating") == 1)
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
new file mode 100644
index 0000000000..5e763a8e90
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+
+class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
+ val sqlContext = spark.sqlContext
+ import sqlContext.implicits._
+
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
+ Seq(
+ (0, 3, 54f),
+ (0, 4, 44f),
+ (0, 5, 42f),
+ (0, 6, 28f),
+ (1, 3, 39f),
+ (2, 3, 51f),
+ (2, 5, 45f),
+ (2, 6, 18f)
+ ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
+ }
+
+ test("topByKey with k < #items") {
+ val topK = getTopK(2)
+ assert(topK.count() === 3)
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f)),
+ 1 -> Array((3, 39f)),
+ 2 -> Array((3, 51f), (5, 45f))
+ )
+ checkTopK(topK, expected)
+ }
+
+ test("topByKey with k > #items") {
+ val topK = getTopK(5)
+ assert(topK.count() === 3)
+
+ val expected = Map(
+ 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+ 1 -> Array((3, 39f)),
+ 2 -> Array((3, 51f), (5, 45f), (6, 18f))
+ )
+ checkTopK(topK, expected)
+ }
+
+ private def checkTopK(
+ topK: Dataset[(Int, Array[(Int, Float)])],
+ expected: Map[Int, Array[(Int, Float)]]): Unit = {
+ topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) }
+ }
+}