author     Xusen Yin <yinxusen@gmail.com>        2015-04-13 11:53:17 -0700
committer  Xiangrui Meng <meng@databricks.com>   2015-04-13 11:53:17 -0700
commit     1e340c3ae4d5361d048a3d6990f144cfc923666f (patch)
tree       81c481ecc9b909c252dc8f57e540b1839f193ebd /mllib
parent     6cc5b3ed3c0c729f97956fa017d8eb7d6b43f90f (diff)
[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel
See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988).

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits:

cb1ecfa [Xusen Yin] change Assignment into case class
b1dd24c [Xusen Yin] add test suite
63c3923 [Xusen Yin] add save load for power iteration clustering
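With this patch, a fitted `PowerIterationClusteringModel` can be persisted with `model.save(sc, path)` and restored with `PowerIterationClusteringModel.load(sc, path)`. A minimal usage sketch follows; the affinity triples, output path, and local master are illustrative and not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{PowerIterationClustering, PowerIterationClusteringModel}

object PICSaveLoadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PICSaveLoadExample").setMaster("local[2]"))

    // A tiny affinity graph given as (srcId, dstId, similarity) triples.
    val similarities = sc.parallelize(Seq(
      (0L, 1L, 1.0), (1L, 2L, 1.0), (2L, 3L, 0.1), (3L, 4L, 1.0)))

    val model = new PowerIterationClustering()
      .setK(2)
      .setMaxIterations(10)
      .run(similarities)

    // New in this patch: write the model out and read it back.
    val path = "/tmp/pic-model"   // illustrative location
    model.save(sc, path)
    val sameModel = PowerIterationClusteringModel.load(sc, path)
    sameModel.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))

    sc.stop()
  }
}
```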
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala       68
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala  34
2 files changed, 97 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 180023922a..aa53e88d59 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -17,15 +17,20 @@
package org.apache.spark.mllib.clustering
-import org.apache.spark.{Logging, SparkException}
+import org.json4s.JsonDSL._
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.{Logging, SparkContext, SparkException}
/**
* :: Experimental ::
@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class PowerIterationClusteringModel(
val k: Int,
- val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
+ val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
+ override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+ PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[clustering]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
+
+ def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val dataRDD = model.assignments.toDF()
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
+ implicit val formats = DefaultFormats
+ val sqlContext = new SQLContext(sc)
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val k = (metadata \ "k").extract[Int]
+ val assignments = sqlContext.parquetFile(Loader.dataPath(path))
+ Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
+
+ val assignmentsRDD = assignments.map {
+ case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
+ }
+
+ new PowerIterationClusteringModel(k, assignmentsRDD)
+ }
+ }
+}
/**
* :: Experimental ::
@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
val v = powerIter(w, maxIterations)
val assignments = kMeans(v, k).mapPartitions({ iter =>
iter.map { case (id, cluster) =>
- new Assignment(id, cluster)
+ Assignment(id, cluster)
}
}, preservesPartitioning = true)
new PowerIterationClusteringModel(k, assignments)
@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
* @param cluster assigned cluster id
*/
@Experimental
- class Assignment(val id: Long, val cluster: Int) extends Serializable
+ case class Assignment(id: Long, cluster: Int)
/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 6315c03a70..6d6fe6fe46 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -18,12 +18,15 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable
+import scala.util.Random
import org.scalatest.FunSuite
+import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
assert(x ~== u1(i.toInt) absTol 1e-14)
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val model = PowerIterationClusteringSuite.createModel(sc, 3, 10)
+ try {
+ model.save(sc, path)
+ val sameModel = PowerIterationClusteringModel.load(sc, path)
+ PowerIterationClusteringSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+}
+
+object PowerIterationClusteringSuite extends FunSuite {
+ def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
+ val assignments = sc.parallelize(
+ (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
+ new PowerIterationClusteringModel(k, assignments)
+ }
+
+ def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
+ assert(a.k === b.k)
+
+ val aAssignments = a.assignments.map(x => (x.id, x.cluster))
+ val bAssignments = b.assignments.map(x => (x.id, x.cluster))
+ val unequalElements = aAssignments.join(bAssignments).filter {
+ case (id, (c1, c2)) => c1 != c2 }.count()
+ assert(unequalElements === 0L)
+ }
}
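For reference, `SaveLoadV1_0.save` writes one JSON metadata record carrying `class`, `version`, and `k`, plus the assignments as Parquet rows of `(id: Long, cluster: Int)`. A sketch of inspecting a saved model directly, assuming the usual `Loader` layout of `<path>/metadata` and `<path>/data`; the path and object name are illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.json4s._
import org.json4s.jackson.JsonMethods._

object InspectSavedPICModel {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("InspectSavedPICModel").setMaster("local[2]"))
    val path = "/tmp/pic-model"  // wherever model.save(sc, path) wrote to

    // Metadata is a single JSON line, e.g. {"class":"...PowerIterationClusteringModel","version":"1.0","k":3}.
    implicit val formats = DefaultFormats
    val metadata = parse(sc.textFile(path + "/metadata").first())
    val k = (metadata \ "k").extract[Int]
    println(s"k = $k")

    // Assignments are Parquet rows matching the Assignment case class: (id: Long, cluster: Int).
    val sqlContext = new SQLContext(sc)
    sqlContext.parquetFile(path + "/data").take(5).foreach(println)

    sc.stop()
  }
}
```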