author     Xiangrui Meng <meng@databricks.com>    2015-02-01 14:13:31 -0800
committer  Xiangrui Meng <meng@databricks.com>    2015-02-01 14:13:31 -0800
commit     4a171225ba628192a5ae43a99dc50508cf12491c (patch)
tree       09c6b6dd42d849dd8bd7f456acad0d182d2f8fe4 /mllib
parent     bdb0680d37614ccdec8933d2dec53793825e43d7 (diff)
[SPARK-5424][MLLIB] make the new ALS impl take generic ID types
This PR makes the ALS implementation take generic ID types, e.g., Long and String, and exposes it as a developer API.

TODO:
- [x] make sure that specialization works (validated in profiler)

srowen You may like this change :) I hit a Scala compiler bug with specialization. It compiles now, but users and items must have the same type. I'm going to check whether specialization really works.

Author: Xiangrui Meng <meng@databricks.com>

Closes #4281 from mengxr/generic-als and squashes the following commits:

96072c3 [Xiangrui Meng] merge master
135f741 [Xiangrui Meng] minor update
c2db5e5 [Xiangrui Meng] make test pass
86588e1 [Xiangrui Meng] use a single ID type for both users and items
74f1f73 [Xiangrui Meng] compile but runtime error at test
e36469a [Xiangrui Meng] add classtags and make it compile
7a5aeb3 [Xiangrui Meng] UserType -> User, ItemType -> Item
c8ee0bc [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into generic-als
72b5006 [Xiangrui Meng] remove generic from pipeline interface
8bbaea0 [Xiangrui Meng] make ALS take generic IDs
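For context, a minimal sketch of how the new developer API could be called once this patch lands, using the ALS.train and Rating signatures introduced in the diff below. The SparkContext `sc`, the sample ratings, and the object name are illustrative, not part of the patch:

import org.apache.spark.SparkContext
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.rdd.RDD

object GenericAlsExample {
  def run(sc: SparkContext): Unit = {
    // IDs are Long here; per this patch, users and items must share the same ID type.
    val ratings: RDD[Rating[Long]] = sc.parallelize(Seq(
      Rating(1L, 10L, 4.0f),
      Rating(1L, 11L, 3.0f),
      Rating(2L, 10L, 5.0f)))

    // Returns (userFactors, itemFactors), each keyed by the original Long IDs.
    val (userFactors, itemFactors) =
      ALS.train(ratings, rank = 2, maxIter = 5, regParam = 0.1)

    userFactors.take(3).foreach { case (id, factor) =>
      println(s"user $id -> ${factor.mkString("[", ", ", "]")}")
    }
  }
}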
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala       213
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala   36
2 files changed, 146 insertions(+), 103 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index aaad548143..979a19d3b2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -20,16 +20,19 @@ package org.apache.spark.ml.recommendation
import java.{util => ju}
import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.util.Sorting
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.netlib.util.intW
import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -199,7 +202,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val ratings = dataset
.select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType))
.map { row =>
- new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
@@ -215,10 +218,19 @@ class ALS extends Estimator[ALSModel] with ALSParams {
}
}
-private[recommendation] object ALS extends Logging {
+/**
+ * :: DeveloperApi ::
+ * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
+ * exposed as a developer API for users who do need other ID types. But it is not recommended
+ * because it increases the shuffle size and memory requirement during training. For simplicity,
+ * users and items must have the same type. The number of distinct users/items should be smaller
+ * than 2 billion.
+ */
+@DeveloperApi
+object ALS extends Logging {
/** Rating class for better code readability. */
- private[recommendation] case class Rating(user: Int, item: Int, rating: Float)
+ case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
/** Cholesky solver for least square problems. */
private[recommendation] class CholeskySolver {
@@ -285,7 +297,7 @@ private[recommendation] object ALS extends Logging {
/** Adds an observation. */
def add(a: Array[Float], b: Float): this.type = {
- require(a.size == k)
+ require(a.length == k)
copyToDouble(a)
blas.dspr(upper, k, 1.0, da, 1, ata)
blas.daxpy(k, b.toDouble, da, 1, atb, 1)
@@ -297,7 +309,7 @@ private[recommendation] object ALS extends Logging {
* Adds an observation with implicit feedback. Note that this does not increment the counter.
*/
def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
- require(a.size == k)
+ require(a.length == k)
// Extension to the original paper to handle b < 0. confidence is a function of |b| instead
// so that it is never negative.
val confidence = 1.0 + alpha * math.abs(b)
@@ -313,8 +325,8 @@ private[recommendation] object ALS extends Logging {
/** Merges another normal equation object. */
def merge(other: NormalEquation): this.type = {
require(other.k == k)
- blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
- blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+ blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
+ blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
n += other.n
this
}
@@ -330,15 +342,16 @@ private[recommendation] object ALS extends Logging {
/**
* Implementation of the ALS algorithm.
*/
- private def train(
- ratings: RDD[Rating],
+ def train[ID: ClassTag](
+ ratings: RDD[Rating[ID]],
rank: Int = 10,
numUserBlocks: Int = 10,
numItemBlocks: Int = 10,
maxIter: Int = 10,
regParam: Double = 1.0,
implicitPrefs: Boolean = false,
- alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = {
+ alpha: Double = 1.0)(
+ implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
val userPart = new HashPartitioner(numUserBlocks)
val itemPart = new HashPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
@@ -441,16 +454,15 @@ private[recommendation] object ALS extends Logging {
*
* @see [[LocalIndexEncoder]]
*/
- private[recommendation] case class InBlock(
- srcIds: Array[Int],
+ private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag](
+ srcIds: Array[ID],
dstPtrs: Array[Int],
dstEncodedIndices: Array[Int],
ratings: Array[Float]) {
/** Size of the block. */
- val size: Int = ratings.size
-
- require(dstEncodedIndices.size == size)
- require(dstPtrs.size == srcIds.size + 1)
+ def size: Int = ratings.length
+ require(dstEncodedIndices.length == size)
+ require(dstPtrs.length == srcIds.length + 1)
}
/**
@@ -460,7 +472,9 @@ private[recommendation] object ALS extends Logging {
* @param rank rank
* @return initialized factor blocks
*/
- private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
+ private def initialize[ID](
+ inBlocks: RDD[(Int, InBlock[ID])],
+ rank: Int): RDD[(Int, FactorBlock)] = {
// Choose a unit vector uniformly at random from the unit sphere, but from the
// "first quadrant" where all elements are nonnegative. This can be done by choosing
// elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
@@ -468,7 +482,7 @@ private[recommendation] object ALS extends Logging {
// (<1%) compared picking elements uniformly at random in [0,1].
inBlocks.map { case (srcBlockId, inBlock) =>
val random = new XORShiftRandom(srcBlockId)
- val factors = Array.fill(inBlock.srcIds.size) {
+ val factors = Array.fill(inBlock.srcIds.length) {
val factor = Array.fill(rank)(random.nextGaussian().toFloat)
val nrm = blas.snrm2(rank, factor, 1)
blas.sscal(rank, 1.0f / nrm, factor, 1)
@@ -481,26 +495,29 @@ private[recommendation] object ALS extends Logging {
/**
* A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
*/
- private[recommendation]
- case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) {
+ private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
+ srcIds: Array[ID],
+ dstIds: Array[ID],
+ ratings: Array[Float]) {
/** Size of the block. */
- val size: Int = srcIds.size
- require(dstIds.size == size)
- require(ratings.size == size)
+ def size: Int = srcIds.length
+ require(dstIds.length == srcIds.length)
+ require(ratings.length == srcIds.length)
}
/**
* Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
*/
- private[recommendation] class RatingBlockBuilder extends Serializable {
+ private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag]
+ extends Serializable {
- private val srcIds = mutable.ArrayBuilder.make[Int]
- private val dstIds = mutable.ArrayBuilder.make[Int]
+ private val srcIds = mutable.ArrayBuilder.make[ID]
+ private val dstIds = mutable.ArrayBuilder.make[ID]
private val ratings = mutable.ArrayBuilder.make[Float]
var size = 0
/** Adds a rating. */
- def add(r: Rating): this.type = {
+ def add(r: Rating[ID]): this.type = {
size += 1
srcIds += r.user
dstIds += r.item
@@ -509,8 +526,8 @@ private[recommendation] object ALS extends Logging {
}
/** Merges another [[RatingBlockBuilder]]. */
- def merge(other: RatingBlock): this.type = {
- size += other.srcIds.size
+ def merge(other: RatingBlock[ID]): this.type = {
+ size += other.srcIds.length
srcIds ++= other.srcIds
dstIds ++= other.dstIds
ratings ++= other.ratings
@@ -518,8 +535,8 @@ private[recommendation] object ALS extends Logging {
}
/** Builds a [[RatingBlock]]. */
- def build(): RatingBlock = {
- RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
+ def build(): RatingBlock[ID] = {
+ RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
}
}
@@ -532,10 +549,10 @@ private[recommendation] object ALS extends Logging {
*
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
*/
- private def partitionRatings(
- ratings: RDD[Rating],
+ private def partitionRatings[ID: ClassTag](
+ ratings: RDD[Rating[ID]],
srcPart: Partitioner,
- dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = {
/* The implementation produces the same result as the following but generates less objects.
@@ -549,7 +566,7 @@ private[recommendation] object ALS extends Logging {
val numPartitions = srcPart.numPartitions * dstPart.numPartitions
ratings.mapPartitions { iter =>
- val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
+ val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID])
iter.flatMap { r =>
val srcBlockId = srcPart.getPartition(r.user)
val dstBlockId = dstPart.getPartition(r.item)
@@ -570,7 +587,7 @@ private[recommendation] object ALS extends Logging {
}
}
}.groupByKey().mapValues { blocks =>
- val builder = new RatingBlockBuilder
+ val builder = new RatingBlockBuilder[ID]
blocks.foreach(builder.merge)
builder.build()
}.setName("ratingBlocks")
@@ -580,9 +597,11 @@ private[recommendation] object ALS extends Logging {
* Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
* @param encoder encoder for dst indices
*/
- private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) {
+ private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag](
+ encoder: LocalIndexEncoder)(
+ implicit ord: Ordering[ID]) {
- private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val srcIds = mutable.ArrayBuilder.make[ID]
private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
private val ratings = mutable.ArrayBuilder.make[Float]
@@ -596,12 +615,12 @@ private[recommendation] object ALS extends Logging {
*/
def add(
dstBlockId: Int,
- srcIds: Array[Int],
+ srcIds: Array[ID],
dstLocalIndices: Array[Int],
ratings: Array[Float]): this.type = {
- val sz = srcIds.size
- require(dstLocalIndices.size == sz)
- require(ratings.size == sz)
+ val sz = srcIds.length
+ require(dstLocalIndices.length == sz)
+ require(ratings.length == sz)
this.srcIds ++= srcIds
this.ratings ++= ratings
var j = 0
@@ -613,7 +632,7 @@ private[recommendation] object ALS extends Logging {
}
/** Builds a [[UncompressedInBlock]]. */
- def build(): UncompressedInBlock = {
+ def build(): UncompressedInBlock[ID] = {
new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
}
}
@@ -621,24 +640,25 @@ private[recommendation] object ALS extends Logging {
/**
* A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
*/
- private[recommendation] class UncompressedInBlock(
- val srcIds: Array[Int],
+ private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag](
+ val srcIds: Array[ID],
val dstEncodedIndices: Array[Int],
- val ratings: Array[Float]) {
+ val ratings: Array[Float])(
+ implicit ord: Ordering[ID]) {
/** Size the of block. */
- def size: Int = srcIds.size
+ def length: Int = srcIds.length
/**
* Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
* sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
* Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
*/
- def compress(): InBlock = {
- val sz = size
+ def compress(): InBlock[ID] = {
+ val sz = length
assert(sz > 0, "Empty in-link block should not exist.")
sort()
- val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
var preSrcId = srcIds(0)
uniqueSrcIdsBuilder += preSrcId
@@ -659,7 +679,7 @@ private[recommendation] object ALS extends Logging {
}
dstCountsBuilder += curCount
val uniqueSrcIds = uniqueSrcIdsBuilder.result()
- val numUniqueSrdIds = uniqueSrcIds.size
+ val numUniqueSrdIds = uniqueSrcIds.length
val dstCounts = dstCountsBuilder.result()
val dstPtrs = new Array[Int](numUniqueSrdIds + 1)
var sum = 0
@@ -673,51 +693,61 @@ private[recommendation] object ALS extends Logging {
}
private def sort(): Unit = {
- val sz = size
+ val sz = length
// Since there might be interleaved log messages, we insert a unique id for easy pairing.
val sortId = Utils.random.nextInt()
logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
val start = System.nanoTime()
- val sorter = new Sorter(new UncompressedInBlockSort)
- sorter.sort(this, 0, size, Ordering[IntWrapper])
+ val sorter = new Sorter(new UncompressedInBlockSort[ID])
+ sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
val duration = (System.nanoTime() - start) / 1e9
logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
}
}
/**
- * A wrapper that holds a primitive integer key.
+ * A wrapper that holds a primitive key.
*
* @see [[UncompressedInBlockSort]]
*/
- private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
- override def compare(that: IntWrapper): Int = {
- key.compare(that.key)
+ private class KeyWrapper[@specialized(Int, Long) ID: ClassTag](
+ implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] {
+
+ var key: ID = _
+
+ override def compare(that: KeyWrapper[ID]): Int = {
+ ord.compare(key, that.key)
+ }
+
+ def setKey(key: ID): this.type = {
+ this.key = key
+ this
}
}
/**
* [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
*/
- private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] {
+ private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag](
+ implicit ord: Ordering[ID])
+ extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] {
- override def newKey(): IntWrapper = new IntWrapper()
+ override def newKey(): KeyWrapper[ID] = new KeyWrapper()
override def getKey(
- data: UncompressedInBlock,
+ data: UncompressedInBlock[ID],
pos: Int,
- reuse: IntWrapper): IntWrapper = {
+ reuse: KeyWrapper[ID]): KeyWrapper[ID] = {
if (reuse == null) {
- new IntWrapper(data.srcIds(pos))
+ new KeyWrapper().setKey(data.srcIds(pos))
} else {
- reuse.key = data.srcIds(pos)
- reuse
+ reuse.setKey(data.srcIds(pos))
}
}
override def getKey(
- data: UncompressedInBlock,
- pos: Int): IntWrapper = {
+ data: UncompressedInBlock[ID],
+ pos: Int): KeyWrapper[ID] = {
getKey(data, pos, null)
}
@@ -730,16 +760,16 @@ private[recommendation] object ALS extends Logging {
data(pos1) = tmp
}
- override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = {
+ override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = {
swapElements(data.srcIds, pos0, pos1)
swapElements(data.dstEncodedIndices, pos0, pos1)
swapElements(data.ratings, pos0, pos1)
}
override def copyRange(
- src: UncompressedInBlock,
+ src: UncompressedInBlock[ID],
srcPos: Int,
- dst: UncompressedInBlock,
+ dst: UncompressedInBlock[ID],
dstPos: Int,
length: Int): Unit = {
System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
@@ -747,15 +777,15 @@ private[recommendation] object ALS extends Logging {
System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
}
- override def allocate(length: Int): UncompressedInBlock = {
+ override def allocate(length: Int): UncompressedInBlock[ID] = {
new UncompressedInBlock(
- new Array[Int](length), new Array[Int](length), new Array[Float](length))
+ new Array[ID](length), new Array[Int](length), new Array[Float](length))
}
override def copyElement(
- src: UncompressedInBlock,
+ src: UncompressedInBlock[ID],
srcPos: Int,
- dst: UncompressedInBlock,
+ dst: UncompressedInBlock[ID],
dstPos: Int): Unit = {
dst.srcIds(dstPos) = src.srcIds(srcPos)
dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
@@ -771,19 +801,20 @@ private[recommendation] object ALS extends Logging {
* @param dstPart partitioner for dst IDs
* @return (in-blocks, out-blocks)
*/
- private def makeBlocks(
+ private def makeBlocks[ID: ClassTag](
prefix: String,
- ratingBlocks: RDD[((Int, Int), RatingBlock)],
+ ratingBlocks: RDD[((Int, Int), RatingBlock[ID])],
srcPart: Partitioner,
- dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = {
+ dstPart: Partitioner)(
+ implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = {
val inBlocks = ratingBlocks.map {
case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
// The implementation is a faster version of
// val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
val start = System.nanoTime()
- val dstIdSet = new OpenHashSet[Int](1 << 20)
+ val dstIdSet = new OpenHashSet[ID](1 << 20)
dstIds.foreach(dstIdSet.add)
- val sortedDstIds = new Array[Int](dstIdSet.size)
+ val sortedDstIds = new Array[ID](dstIdSet.size)
var i = 0
var pos = dstIdSet.nextPos(0)
while (pos != -1) {
@@ -792,10 +823,10 @@ private[recommendation] object ALS extends Logging {
i += 1
}
assert(i == dstIdSet.size)
- ju.Arrays.sort(sortedDstIds)
- val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
+ Sorting.quickSort(sortedDstIds)
+ val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
i = 0
- while (i < sortedDstIds.size) {
+ while (i < sortedDstIds.length) {
dstIdToLocalIndex.update(sortedDstIds(i), i)
i += 1
}
@@ -806,7 +837,7 @@ private[recommendation] object ALS extends Logging {
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
.mapValues { iter =>
val builder =
- new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions))
+ new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
}
@@ -817,7 +848,7 @@ private[recommendation] object ALS extends Logging {
val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
var i = 0
val seen = new Array[Boolean](dstPart.numPartitions)
- while (i < srcIds.size) {
+ while (i < srcIds.length) {
var j = dstPtrs(i)
ju.Arrays.fill(seen, false)
while (j < dstPtrs(i + 1)) {
@@ -851,16 +882,16 @@ private[recommendation] object ALS extends Logging {
*
* @return dst factors
*/
- private def computeFactors(
+ private def computeFactors[ID](
srcFactorBlocks: RDD[(Int, FactorBlock)],
srcOutBlocks: RDD[(Int, OutBlock)],
- dstInBlocks: RDD[(Int, InBlock)],
+ dstInBlocks: RDD[(Int, InBlock[ID])],
rank: Int,
regParam: Double,
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
- val numSrcBlocks = srcFactorBlocks.partitions.size
+ val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
case (srcBlockId, (srcOutBlock, srcFactors)) =>
@@ -868,18 +899,18 @@ private[recommendation] object ALS extends Logging {
(dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
}
}
- val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+ val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
dstInBlocks.join(merged).mapValues {
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
srcFactors.foreach { case (srcBlockId, factors) =>
sortedSrcFactors(srcBlockId) = factors
}
- val dstFactors = new Array[Array[Float]](dstIds.size)
+ val dstFactors = new Array[Array[Float]](dstIds.length)
var j = 0
val ls = new NormalEquation(rank)
val solver = new CholeskySolver // TODO: add NNLS solver
- while (j < dstIds.size) {
+ while (j < dstIds.length) {
ls.reset()
if (implicitPrefs) {
ls.merge(YtY.get)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 9da253c61d..07aff56fb7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -155,7 +155,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
}
test("RatingBlockBuilder") {
- val emptyBuilder = new RatingBlockBuilder()
+ val emptyBuilder = new RatingBlockBuilder[Int]()
assert(emptyBuilder.size === 0)
val emptyBlock = emptyBuilder.build()
assert(emptyBlock.srcIds.isEmpty)
@@ -179,12 +179,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
test("UncompressedInBlock") {
val encoder = new LocalIndexEncoder(10)
- val uncompressed = new UncompressedInBlockBuilder(encoder)
+ val uncompressed = new UncompressedInBlockBuilder[Int](encoder)
.add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
.add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
.build()
- assert(uncompressed.size === 5)
- val records = Seq.tabulate(uncompressed.size) { i =>
+ assert(uncompressed.length === 5)
+ val records = Seq.tabulate(uncompressed.length) { i =>
val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
val dstBlockId = encoder.blockId(dstEncodedIndex)
val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
@@ -228,15 +228,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItems: Int,
rank: Int,
noiseStd: Double = 0.0,
- seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = {
val trainingFraction = 0.6
val testFraction = 0.3
val totalFraction = trainingFraction + testFraction
val random = new Random(seed)
val userFactors = genFactors(numUsers, rank, random)
val itemFactors = genFactors(numItems, rank, random)
- val training = ArrayBuffer.empty[Rating]
- val test = ArrayBuffer.empty[Rating]
+ val training = ArrayBuffer.empty[Rating[Int]]
+ val test = ArrayBuffer.empty[Rating[Int]]
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
val x = random.nextDouble()
if (x < totalFraction) {
@@ -268,7 +268,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItems: Int,
rank: Int,
noiseStd: Double = 0.0,
- seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = {
// The assumption of the implicit feedback model is that unobserved ratings are more likely to
// be negatives.
val positiveFraction = 0.8
@@ -279,8 +279,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val random = new Random(seed)
val userFactors = genFactors(numUsers, rank, random)
val itemFactors = genFactors(numItems, rank, random)
- val training = ArrayBuffer.empty[Rating]
- val test = ArrayBuffer.empty[Rating]
+ val training = ArrayBuffer.empty[Rating[Int]]
+ val test = ArrayBuffer.empty[Rating[Int]]
for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
val threshold = if (rating > 0) positiveFraction else negativeFraction
@@ -340,8 +340,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
* @param targetRMSE target test RMSE
*/
def testALS(
- training: RDD[Rating],
- test: RDD[Rating],
+ training: RDD[Rating[Int]],
+ test: RDD[Rating[Int]],
rank: Int,
maxIter: Int,
regParam: Double,
@@ -432,4 +432,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
targetRMSE = 0.3)
}
+
+ test("using generic ID types") {
+ val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+
+ val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating))
+ val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4)
+ assert(longUserFactors.first()._1.getClass === classOf[Long])
+
+ val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating))
+ val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
+ assert(strUserFactors.first()._1.getClass === classOf[String])
+ }
}
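
As a side note on the compress() method touched above: it converts the sorted (srcId, dstEncodedIndex, rating) triples from COO-style storage into CSC-style storage, collapsing repeated srcIds and recording pointer offsets. A standalone, simplified sketch of that idea (not the patch's code; it records boundaries directly instead of counting and prefix-summing as ALS.scala does):

object CompressSketch {
  // Given srcIds already sorted ascending, return the distinct srcIds and the
  // pointer array such that entries for uniqueSrcIds(i) live in [ptrs(i), ptrs(i + 1)).
  def compress(sortedSrcIds: Array[Long]): (Array[Long], Array[Int]) = {
    require(sortedSrcIds.nonEmpty, "Empty in-link block should not exist.")
    val uniqueSrcIds = scala.collection.mutable.ArrayBuilder.make[Long]
    val ptrs = scala.collection.mutable.ArrayBuilder.make[Int]
    ptrs += 0
    var prev = sortedSrcIds(0)
    uniqueSrcIds += prev
    var i = 1
    while (i < sortedSrcIds.length) {
      val cur = sortedSrcIds(i)
      if (cur != prev) {
        uniqueSrcIds += cur
        ptrs += i // start offset of the new srcId's entries
        prev = cur
      }
      i += 1
    }
    ptrs += sortedSrcIds.length // closing pointer
    (uniqueSrcIds.result(), ptrs.result())
  }

  def main(args: Array[String]): Unit = {
    // srcIds 0,0,1,2,2,2 compress to unique IDs (0,1,2) with pointers (0,2,3,6).
    val (ids, ptrs) = compress(Array(0L, 0L, 1L, 2L, 2L, 2L))
    println(ids.mkString(",") + " | " + ptrs.mkString(","))
  }
}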