diff options
2 files changed, 3 insertions, 81 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ed2f8b41bc..969e23be21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -21,45 +21,23 @@ import java.lang.{Integer => JavaInteger} import org.jblas.DoubleMatrix -import org.apache.spark.Logging +import org.apache.spark.SparkContext._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. * - * Note: If you create the model directly using constructor, please be aware that fast prediction - * requires cached user/product features and their associated partitioners. - * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. */ -class MatrixFactorizationModel( +class MatrixFactorizationModel private[mllib] ( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { - - require(rank > 0) - validateFeatures("User", userFeatures) - validateFeatures("Product", productFeatures) - - /** Validates factors and warns users if there are performance concerns. */ - private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = { - require(features.first()._2.size == rank, - s"$name feature dimension does not match the rank $rank.") - if (features.partitioner.isEmpty) { - logWarning(s"$name factor does not have a partitioner. " - + "Prediction on individual records could be slow.") - } - if (features.getStorageLevel == StorageLevel.NONE) { - logWarning(s"$name factor is not cached. Prediction could be slow.") - } - } - + val productFeatures: RDD[(Int, Array[Double])]) extends Serializable { /** Predict the rating of one user for one product. */ def predict(user: Int, product: Int): Double = { val userVector = new DoubleMatrix(userFeatures.lookup(user).head) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala deleted file mode 100644 index b9caecc904..0000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.recommendation - -import org.scalatest.FunSuite - -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.rdd.RDD - -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { - - val rank = 2 - var userFeatures: RDD[(Int, Array[Double])] = _ - var prodFeatures: RDD[(Int, Array[Double])] = _ - - override def beforeAll(): Unit = { - super.beforeAll() - userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))) - prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0)))) - } - - test("constructor") { - val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) - assert(model.predict(0, 2) ~== 17.0 relTol 1e-14) - - intercept[IllegalArgumentException] { - new MatrixFactorizationModel(1, userFeatures, prodFeatures) - } - - val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0)))) - intercept[IllegalArgumentException] { - new MatrixFactorizationModel(rank, userFeatures1, prodFeatures) - } - - val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0)))) - intercept[IllegalArgumentException] { - new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) - } - } -} |