aboutsummaryrefslogtreecommitdiff
path: root/graphx
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2015-02-15 20:41:27 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-15 20:41:27 -0800
commitacf2558dc92901c342262c35eebb95f2a9b7a9ae (patch)
tree169f836c34f295bc525fc43c167c79f4404ed814 /graphx
parentcd4a15366244657c4b7936abe5054754534366f2 (diff)
downloadspark-acf2558dc92901c342262c35eebb95f2a9b7a9ae.tar.gz
spark-acf2558dc92901c342262c35eebb95f2a9b7a9ae.tar.bz2
spark-acf2558dc92901c342262c35eebb95f2a9b7a9ae.zip
SPARK-5815 [MLLIB] Deprecate SVDPlusPlus APIs that expose DoubleMatrix from JBLAS
Deprecate SVDPlusPlus.run and introduce SVDPlusPlus.runSVDPlusPlus with return type that doesn't include DoubleMatrix CC mengxr Author: Sean Owen <sowen@cloudera.com> Closes #4614 from srowen/SPARK-5815 and squashes the following commits: 288cb05 [Sean Owen] Clarify deprecation plans in scaladoc 497458e [Sean Owen] Deprecate SVDPlusPlus.run and introduce SVDPlusPlus.runSVDPlusPlus with return type that doesn't include DoubleMatrix
Diffstat (limited to 'graphx')
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala25
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala2
2 files changed, 26 insertions, 1 deletions
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index 112ed09ef4..fc84cfbe64 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -17,6 +17,8 @@
package org.apache.spark.graphx.lib
+import org.apache.spark.annotation.Experimental
+
import scala.util.Random
import org.jblas.DoubleMatrix
import org.apache.spark.rdd._
@@ -38,6 +40,8 @@ object SVDPlusPlus {
extends Serializable
/**
+ * :: Experimental ::
+ *
* Implement SVD++ based on "Factorization Meets the Neighborhood:
* a Multifaceted Collaborative Filtering Model",
* available at [[http://public.research.att.com/~volinsky/netflix/kdd08koren.pdf]].
@@ -45,12 +49,33 @@ object SVDPlusPlus {
* The prediction rule is rui = u + bu + bi + qi*(pu + |N(u)|^(-0.5)*sum(y)),
* see the details on page 6.
*
+ * This method temporarily replaces `run()`, and replaces `DoubleMatrix` in `run()`'s return
+ * value with `Array[Double]`. In 1.4.0, this method will be deprecated, but will be copied
+ * to replace `run()`, which will then be undeprecated.
+ *
* @param edges edges for constructing the graph
*
* @param conf SVDPlusPlus parameters
*
* @return a graph with vertex attributes containing the trained model
*/
+ @Experimental
+ def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf)
+ : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
+ {
+ val (graph, u) = run(edges, conf)
+ // Convert DoubleMatrix to Array[Double]:
+ val newVertices = graph.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4))
+ (Graph(newVertices, graph.edges), u)
+ }
+
+ /**
+ * This method is deprecated in favor of `runSVDPlusPlus()`, which replaces `DoubleMatrix`
+ * with `Array[Double]` in its return value. This method is deprecated. It will effectively
+ * be removed in 1.4.0 when `runSVDPlusPlus()` is copied to replace `run()`, and hence the
+ * return type of this method changes.
+ */
+ @deprecated("Call runSVDPlusPlus", "1.3.0")
def run(edges: RDD[Edge[Double]], conf: Conf)
: (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) =
{
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index e01df56e94..9987a4b1a3 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -32,7 +32,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
- var (graph, u) = SVDPlusPlus.run(edges, conf)
+ var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf)
graph.cache()
val err = graph.vertices.collect().map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0