author     Xiangrui Meng <meng@databricks.com>  2015-02-02 15:55:44 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-02-02 15:55:44 -0800
commit     46d50f151c02c6892fc84a37fdf2a521dc774d1c (patch)
tree       5788c15d5f400f97909ee5bb423d1ba9c20ac153 /mllib/src
parent     1646f89d967913ee1f231d9606f8502d13c25804 (diff)
[SPARK-5513][MLLIB] Add nonnegative option to ml's ALS
This PR ports the NNLS solver to the new ALS implementation.

CC: coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #4302 from mengxr/SPARK-5513 and squashes the following commits:

4cbdab0 [Xiangrui Meng] fix serialization
88de634 [Xiangrui Meng] add NNLS to ml's ALS
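For orientation, a minimal usage sketch of the option this patch adds. The input DataFrame `ratings` is a hypothetical example (its default column names "user", "item", and "rating" match the ALSParams defaults in the diff below); only `setNonnegative` is new here:

    import org.apache.spark.ml.recommendation.ALS

    // `ratings` is a hypothetical DataFrame with "user", "item", "rating" columns.
    val als = new ALS()
      .setMaxIter(10)
      .setRegParam(0.1)
      .setNonnegative(true)  // new in this patch: solve each subproblem with NNLS
    val model = als.fit(ratings)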
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala      | 95
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala    |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 11
3 files changed, 96 insertions(+), 14 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 979a19d3b2..82d21d5e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -25,12 +25,14 @@ import scala.util.Sorting
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.jblas.DoubleMatrix
import org.netlib.util.intW
import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
+import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
@@ -80,6 +82,10 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
def getRatingCol: String = get(ratingCol)
+ val nonnegative = new BooleanParam(
+ this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false))
+ def getNonnegative: Boolean = get(nonnegative)
+
/**
* Validates and transforms the input schema.
* @param schema input schema
@@ -186,6 +192,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
def setPredictionCol(value: String): this.type = set(predictionCol, value)
def setMaxIter(value: Int): this.type = set(maxIter, value)
def setRegParam(value: Double): this.type = set(regParam, value)
+ def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
/** Sets both numUserBlocks and numItemBlocks to the specific value. */
def setNumBlocks(value: Int): this.type = {
@@ -207,7 +214,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
- alpha = map(alpha))
+ alpha = map(alpha), nonnegative = map(nonnegative))
val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
Params.inheritValues(map, this, model)
model
@@ -232,11 +239,16 @@ object ALS extends Logging {
/** Rating class for better code readability. */
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
+ /** Trait for least squares solvers applied to the normal equation. */
+ private[recommendation] trait LeastSquaresNESolver extends Serializable {
+ /** Solves a least squares problem (possibly with other constraints). */
+ def solve(ne: NormalEquation, lambda: Double): Array[Float]
+ }
+
/** Cholesky solver for least square problems. */
- private[recommendation] class CholeskySolver {
+ private[recommendation] class CholeskySolver extends LeastSquaresNESolver {
private val upper = "U"
- private val info = new intW(0)
/**
* Solves a least squares problem with L2 regularization:
@@ -247,7 +259,7 @@ object ALS extends Logging {
* @param lambda regularization constant, which will be scaled by n
* @return the solution x
*/
- def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
val k = ne.k
// Add scaled lambda to the diagonals of AtA.
val scaledlambda = lambda * ne.n
@@ -258,6 +270,7 @@ object ALS extends Logging {
i += j
j += 1
}
+ val info = new intW(0)
lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info)
val code = info.`val`
assert(code == 0, s"lapack.dppsv returned $code.")
@@ -272,6 +285,63 @@ object ALS extends Logging {
}
}
+ /** NNLS solver. */
+ private[recommendation] class NNLSSolver extends LeastSquaresNESolver {
+ private var rank: Int = -1
+ private var workspace: NNLS.Workspace = _
+ private var ata: DoubleMatrix = _
+ private var initialized: Boolean = false
+
+ private def initialize(rank: Int): Unit = {
+ if (!initialized) {
+ this.rank = rank
+ workspace = NNLS.createWorkspace(rank)
+ ata = new DoubleMatrix(rank, rank)
+ initialized = true
+ } else {
+ require(this.rank == rank)
+ }
+ }
+
+ /**
+ * Solves a nonnegative least squares problem with L2 regularization:
+ *
+ * min_x_ norm(A x - b)^2^ + lambda * n * norm(x)^2^
+ * subject to x >= 0
+ */
+ override def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ val rank = ne.k
+ initialize(rank)
+ fillAtA(ne.ata, lambda * ne.n)
+ val x = NNLS.solve(ata, new DoubleMatrix(rank, 1, ne.atb: _*), workspace)
+ ne.reset()
+ x.map(x => x.toFloat)
+ }
+
+ /**
+ * Given AtA in the packed upper-triangular order used by NormalEquation, compute the full
+ * symmetric square matrix that it represents, storing it into ata, and add lambda to its
+ * diagonal.
+ */
+ private def fillAtA(triAtA: Array[Double], lambda: Double) {
+ var i = 0
+ var pos = 0
+ var a = 0.0
+ val data = ata.data
+ while (i < rank) {
+ var j = 0
+ while (j <= i) {
+ a = triAtA(pos)
+ data(i * rank + j) = a
+ data(j * rank + i) = a
+ pos += 1
+ j += 1
+ }
+ data(i * rank + i) += lambda
+ i += 1
+ }
+ }
+ }
+
/** Representing a normal equation (ALS' subproblem). */
private[recommendation] class NormalEquation(val k: Int) extends Serializable {
@@ -350,12 +420,14 @@ object ALS extends Logging {
maxIter: Int = 10,
regParam: Double = 1.0,
implicitPrefs: Boolean = false,
- alpha: Double = 1.0)(
+ alpha: Double = 1.0,
+ nonnegative: Boolean = false)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
val userPart = new HashPartitioner(numUserBlocks)
val itemPart = new HashPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
+ val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
// materialize blockRatings and user blocks
@@ -374,20 +446,20 @@ object ALS extends Logging {
userFactors.setName(s"userFactors-$iter").persist()
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
- userLocalIndexEncoder, implicitPrefs, alpha)
+ userLocalIndexEncoder, implicitPrefs, alpha, solver)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist()
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
- itemLocalIndexEncoder, implicitPrefs, alpha)
+ itemLocalIndexEncoder, implicitPrefs, alpha, solver)
previousUserFactors.unpersist()
}
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
- userLocalIndexEncoder)
+ userLocalIndexEncoder, solver = solver)
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
- itemLocalIndexEncoder)
+ itemLocalIndexEncoder, solver = solver)
}
}
val userIdAndFactors = userInBlocks
@@ -879,6 +951,7 @@ object ALS extends Logging {
* @param srcEncoder encoder for src local indices
* @param implicitPrefs whether to use implicit preference
* @param alpha the alpha constant in the implicit preference formulation
+ * @param solver solver for least squares problems
*
* @return dst factors
*/
@@ -890,7 +963,8 @@ object ALS extends Logging {
regParam: Double,
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
- alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
+ alpha: Double = 1.0,
+ solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -909,7 +983,6 @@ object ALS extends Logging {
val dstFactors = new Array[Array[Float]](dstIds.length)
var j = 0
val ls = new NormalEquation(rank)
- val solver = new CholeskySolver // TODO: add NNLS solver
while (j < dstIds.length) {
ls.reset()
if (implicitPrefs) {
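A note on the hand-off inside NNLSSolver above: `ne.ata` stores only the upper triangle of AtA in packed column-major order (the same layout `lapack.dppsv` consumes in CholeskySolver), so for k = 3 the array holds a11, a12, a22, a13, a23, a33. A standalone sketch of the expansion `fillAtA` performs, with illustrative names not taken from the patch:

    // Expand a packed upper-triangular k-by-k matrix into a full symmetric
    // array and add the scaled regularization constant to the diagonal.
    def unpackSymmetric(tri: Array[Double], k: Int, lambda: Double): Array[Double] = {
      val full = new Array[Double](k * k)
      var pos = 0
      var i = 0
      while (i < k) {
        var j = 0
        while (j <= i) {
          full(i * k + j) = tri(pos)  // element (j, i) of the upper triangle
          full(j * k + i) = tri(pos)  // mirror into the lower triangle
          pos += 1
          j += 1
        }
        full(i * k + i) += lambda
        i += 1
      }
      full
    }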
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index fef062e02b..ccd93b318b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -19,13 +19,11 @@ package org.apache.spark.mllib.optimization
import org.jblas.{DoubleMatrix, SimpleBlas}
-import org.apache.spark.annotation.DeveloperApi
-
/**
* Object used to solve nonnegative least squares problems using a modified
* projected gradient method.
*/
-private[mllib] object NNLS {
+private[spark] object NNLS {
class Workspace(val n: Int) {
val scratch = new DoubleMatrix(n, 1)
val grad = new DoubleMatrix(n, 1)
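The important line in this file is the visibility widening from `private[mllib]` to `private[spark]`: the new ALS lives in `org.apache.spark.ml.recommendation`, outside the `mllib` package, so it could not reach the solver before. The call pattern used by NNLSSolver reduces to the following sketch (callers must themselves sit inside `org.apache.spark`; the 2x2 problem data is illustrative):

    import org.jblas.DoubleMatrix
    import org.apache.spark.mllib.optimization.NNLS

    // Solve min ||A x - b||^2 subject to x >= 0 from its normal equation.
    // With A^T A = I and A^T b = (1, -1), the unconstrained optimum (1, -1)
    // is projected by the constraint to approximately (1, 0).
    val k = 2
    val ata = new DoubleMatrix(k, k, 1.0, 0.0, 0.0, 1.0)
    val atb = new DoubleMatrix(k, 1, 1.0, -1.0)
    val x: Array[Double] = NNLS.solve(ata, atb, NNLS.createWorkspace(k))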
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 07aff56fb7..ee08c3c327 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -444,4 +444,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
assert(strUserFactors.first()._1.getClass === classOf[String])
}
+
+ test("nonnegative constraint") {
+ val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true)
+ def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = {
+ factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _)
+ }
+ assert(isNonnegative(userFactors))
+ assert(isNonnegative(itemFactors))
+ // TODO: Validate the solution.
+ }
}
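On the remaining TODO: one property that can be verified without knowing the true factors is the KKT condition of each nonnegative subproblem. For min ||A x - b||^2 subject to x >= 0, the quantity g = AtA x - Atb (the gradient up to a factor of two) must be approximately zero on coordinates where x_i > 0 and nonnegative where x_i = 0. A standalone sketch at the solver level; the problem data and tolerances are made up for illustration:

    import org.jblas.DoubleMatrix
    import org.apache.spark.mllib.optimization.NNLS

    val k = 2
    val ata = new DoubleMatrix(k, k, 2.0, 1.0, 1.0, 2.0)  // A^T A
    val atb = new DoubleMatrix(k, 1, 1.0, -1.0)           // A^T b
    val x = NNLS.solve(ata, atb, NNLS.createWorkspace(k)) // expect ~ (0.5, 0.0)
    val grad = Array.tabulate(k) { i =>
      (0 until k).map(j => ata.get(i, j) * x(j)).sum - atb.get(i, 0)
    }
    x.zip(grad).foreach { case (xi, gi) =>
      if (xi > 1e-6) assert(math.abs(gi) < 1e-3) // interior coordinate: g_i ~ 0
      else assert(gi > -1e-3)                    // boundary coordinate: g_i >= 0
    }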