aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala28
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala56
2 files changed, 3 insertions, 81 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index ed2f8b41bc..969e23be21 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -21,45 +21,23 @@ import java.lang.{Integer => JavaInteger}
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
/**
* Model representing the result of matrix factorization.
*
- * Note: If you create the model directly using constructor, please be aware that fast prediction
- * requires cached user/product features and their associated partitioners.
- *
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
-class MatrixFactorizationModel(
+class MatrixFactorizationModel private[mllib] (
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
- val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
-
- require(rank > 0)
- validateFeatures("User", userFeatures)
- validateFeatures("Product", productFeatures)
-
- /** Validates factors and warns users if there are performance concerns. */
- private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
- require(features.first()._2.size == rank,
- s"$name feature dimension does not match the rank $rank.")
- if (features.partitioner.isEmpty) {
- logWarning(s"$name factor does not have a partitioner. "
- + "Prediction on individual records could be slow.")
- }
- if (features.getStorageLevel == StorageLevel.NONE) {
- logWarning(s"$name factor is not cached. Prediction could be slow.")
- }
- }
-
+ val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
deleted file mode 100644
index b9caecc904..0000000000
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.recommendation
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.rdd.RDD
-
-class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
-
- val rank = 2
- var userFeatures: RDD[(Int, Array[Double])] = _
- var prodFeatures: RDD[(Int, Array[Double])] = _
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
- prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
- }
-
- test("constructor") {
- val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
- assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)
-
- intercept[IllegalArgumentException] {
- new MatrixFactorizationModel(1, userFeatures, prodFeatures)
- }
-
- val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
- intercept[IllegalArgumentException] {
- new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
- }
-
- val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
- intercept[IllegalArgumentException] {
- new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
- }
- }
-}