diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 07aff56fb7..ee08c3c327 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -444,4 +444,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) assert(strUserFactors.first()._1.getClass === classOf[String]) } + + test("nonnegative constraint") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { + factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) + } + assert(isNonnegative(userFactors)) + assert(isNonnegative(itemFactors)) + // TODO: Validate the solution. + } } |