aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala11
1 files changed, 11 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 07aff56fb7..ee08c3c327 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -444,4 +444,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4)
assert(strUserFactors.first()._1.getClass === classOf[String])
}
+
+ test("nonnegative constraint") {
+ val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true)
+ def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = {
+ factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _)
+ }
+ assert(isNonnegative(userFactors))
+ assert(isNonnegative(itemFactors))
+ // TODO: Validate the solution.
+ }
}