 mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 55 +++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 49 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 93aa41e499..43d219a49c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}

import scala.collection.mutable

+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.Path
import org.json4s._
@@ -80,6 +81,30 @@ class MatrixFactorizationModel(
  }

  /**
+   * Return approximate numbers of distinct users and products in the given usersProducts pairs.
+   * This method is based on `countApproxDistinct` in class `RDD`.
+   *
+   * @param usersProducts RDD of (user, product) pairs.
+   * @return approximate numbers of distinct users and products.
+   */
+  private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = {
+    val zeroCounterUser = new HyperLogLogPlus(4, 0)
+    val zeroCounterProduct = new HyperLogLogPlus(4, 0)
+    val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))(
+      (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => {
+        hllTuple._1.offer(v._1)
+        hllTuple._2.offer(v._2)
+        hllTuple
+      },
+      (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => {
+        h1._1.addAll(h2._1)
+        h1._2.addAll(h2._2)
+        h1
+      })
+    (aggregated._1.cardinality(), aggregated._2.cardinality())
+  }
+
+  /**
   * Predict the rating of many users for many products.
   * The output RDD has an element per each element in the input RDD (including all duplicates)
   * unless a user or product is missing in the training set.
@@ -88,12 +113,30 @@ class MatrixFactorizationModel(
   * @return RDD of Ratings.
   */
  def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
-    val users = userFeatures.join(usersProducts).map {
-      case (user, (uFeatures, product)) => (product, (user, uFeatures))
-    }
-    users.join(productFeatures).map {
-      case (product, ((user, uFeatures), pFeatures)) =>
-        Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+    // Previously the partitioning of the resulting ratings was based only on the given
+    // products. So if the usersProducts given for prediction contain only a few products,
+    // or even a single product, the generated ratings are pushed into one or a few
+    // partitions and cannot exploit high parallelism.
+    // Here we compute approximate counts of the distinct users and products, and then
+    // decide whether to partition the ratings by users or by products.
+    val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts)
+
+    if (usersCount < productsCount) {
+      val users = userFeatures.join(usersProducts).map {
+        case (user, (uFeatures, product)) => (product, (user, uFeatures))
+      }
+      users.join(productFeatures).map {
+        case (product, ((user, uFeatures), pFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
+    } else {
+      val products = productFeatures.join(usersProducts.map(_.swap)).map {
+        case (product, (pFeatures, user)) => (user, (product, pFeatures))
+      }
+      products.join(userFeatures).map {
+        case (user, ((product, pFeatures), uFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
    }
  }
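
For context, the approximate counting in this patch uses HyperLogLogPlus from the stream-lib library that RDD.countApproxDistinct already relies on, so no new dependency is introduced. Below is a minimal standalone sketch (not part of the patch; the object name and toy data are made up for illustration) that mirrors the per-partition seqOp and the merging combOp of countApproxDistinctUserProduct on plain Scala collections:

import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus

object HllCountSketch {
  def main(args: Array[String]): Unit = {
    // Toy (user, product) pairs split into two "partitions".
    val part1 = Seq((1, 10), (1, 11), (2, 10))
    val part2 = Seq((2, 12), (3, 10), (3, 11))

    // Per-partition counting, analogous to the seqOp passed to aggregate in the patch.
    def countPartition(pairs: Seq[(Int, Int)]): (HyperLogLogPlus, HyperLogLogPlus) = {
      val users = new HyperLogLogPlus(4, 0)    // precision p = 4, sparse mode off, as in the patch
      val products = new HyperLogLogPlus(4, 0)
      pairs.foreach { case (user, product) =>
        users.offer(user)
        products.offer(product)
      }
      (users, products)
    }

    // Merging per-partition counters, analogous to the combOp in the patch.
    val (u1, p1) = countPartition(part1)
    val (u2, p2) = countPartition(part2)
    u1.addAll(u2)
    p1.addAll(p2)

    // Estimates should be close to the exact counts (3 distinct users, 4 distinct products),
    // though with only 2^4 registers they can be off by a small amount.
    println(s"~users = ${u1.cardinality()}, ~products = ${p1.cardinality()}")
  }
}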
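
And a hypothetical end-to-end usage sketch (again not part of the patch; the object name, local master, and toy data are assumptions): train a small ALS model and call the RDD-based predict with several users but a single product, which is exactly the case this change targets. Since the approximate product count is smaller than the user count, the new code takes the second branch and keys the final join by user, so the result is no longer squeezed into a single partition.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

object PredictPartitioningSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("PredictPartitioningSketch").setMaster("local[4]")
    val sc = new SparkContext(conf)

    // Toy training data: 3 users rating 3 products.
    val ratings = sc.parallelize(Seq(
      Rating(1, 10, 4.0), Rating(1, 11, 3.0), Rating(2, 10, 5.0),
      Rating(2, 12, 2.0), Rating(3, 11, 1.0), Rating(3, 12, 4.0)))
    val model = ALS.train(ratings, 5, 10, 0.01)

    // Several users, one product: usersCount > productsCount, so predict joins productFeatures
    // first and keys the final join by user.
    val usersProducts = sc.parallelize(Seq((1, 10), (2, 10), (3, 10)))
    model.predict(usersProducts).collect().foreach(println)

    sc.stop()
  }
}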