aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-25 20:11:40 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-25 20:11:40 -0800
commitb5fb1410c5eed1156decb4f9fcc22436a658ce4d (patch)
tree792e1ab1ce54f4665424b13c0aa381588656b60e /mllib
parent4d95526a75ad1630554683fe7a7e583da44ba6e4 (diff)
downloadspark-b5fb1410c5eed1156decb4f9fcc22436a658ce4d.tar.gz
spark-b5fb1410c5eed1156decb4f9fcc22436a658ce4d.tar.bz2
spark-b5fb1410c5eed1156decb4f9fcc22436a658ce4d.zip
[SPARK-4604][MLLIB] make MatrixFactorizationModel public
User could construct an MF model directly. I added a note about the performance. Author: Xiangrui Meng <meng@databricks.com> Closes #3459 from mengxr/SPARK-4604 and squashes the following commits: f64bcd3 [Xiangrui Meng] organize imports ed08214 [Xiangrui Meng] check preconditions and unit tests a624c12 [Xiangrui Meng] make MatrixFactorizationModel public
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala28
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala56
2 files changed, 81 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 969e23be21..ed2f8b41bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -21,23 +21,45 @@ import java.lang.{Integer => JavaInteger}
import org.jblas.DoubleMatrix
-import org.apache.spark.SparkContext._
+import org.apache.spark.Logging
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
/**
* Model representing the result of matrix factorization.
*
+ * Note: If you create the model directly using constructor, please be aware that fast prediction
+ * requires cached user/product features and their associated partitioners.
+ *
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
-class MatrixFactorizationModel private[mllib] (
+class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
- val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
+ val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+
+ require(rank > 0)
+ validateFeatures("User", userFeatures)
+ validateFeatures("Product", productFeatures)
+
+ /** Validates factors and warns users if there are performance concerns. */
+ private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
+ require(features.first()._2.size == rank,
+ s"$name feature dimension does not match the rank $rank.")
+ if (features.partitioner.isEmpty) {
+ logWarning(s"$name factor does not have a partitioner. "
+ + "Prediction on individual records could be slow.")
+ }
+ if (features.getStorageLevel == StorageLevel.NONE) {
+ logWarning(s"$name factor is not cached. Prediction could be slow.")
+ }
+ }
+
/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
new file mode 100644
index 0000000000..b9caecc904
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.recommendation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+
+class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
+
+ val rank = 2
+ var userFeatures: RDD[(Int, Array[Double])] = _
+ var prodFeatures: RDD[(Int, Array[Double])] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
+ prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
+ }
+
+ test("constructor") {
+ val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+ assert(model.predict(0, 2) ~== 17.0 relTol 1e-14)
+
+ intercept[IllegalArgumentException] {
+ new MatrixFactorizationModel(1, userFeatures, prodFeatures)
+ }
+
+ val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
+ intercept[IllegalArgumentException] {
+ new MatrixFactorizationModel(rank, userFeatures1, prodFeatures)
+ }
+
+ val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0))))
+ intercept[IllegalArgumentException] {
+ new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
+ }
+ }
+}