From 1e340c3ae4d5361d048a3d6990f144cfc923666f Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 13 Apr 2015 11:53:17 -0700 Subject: [SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988). Author: Xusen Yin Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits: cb1ecfa [Xusen Yin] change Assignment into case class b1dd24c [Xusen Yin] add test suite 63c3923 [Xusen Yin] add save load for power iteration clustering --- .../clustering/PowerIterationClustering.scala | 68 ++++++++++++++++++++-- .../clustering/PowerIterationClusteringSuite.scala | 34 +++++++++++ 2 files changed, 97 insertions(+), 5 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 180023922a..aa53e88d59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,15 +17,20 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.{Logging, SparkException} +import org.json4s.JsonDSL._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.{Logging, SparkContext, SparkException} /** * :: Experimental :: @@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + + override def save(sc: SparkContext, path: String): Unit = { + PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) + } + + private[clustering] + object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + + def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val dataRDD = model.assignments.toDF() + dataRDD.saveAsParquetFile(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { + implicit val formats = DefaultFormats + val sqlContext = new SQLContext(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val k = (metadata \ "k").extract[Int] + val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) + + val assignmentsRDD = assignments.map { + case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster) + } + + new PowerIterationClusteringModel(k, assignmentsRDD) + } + } +} /** * :: Experimental :: @@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] ( val v = powerIter(w, maxIterations) val assignments = kMeans(v, k).mapPartitions({ iter => iter.map { case (id, cluster) => - new Assignment(id, cluster) + Assignment(id, cluster) } }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) @@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging { * @param cluster assigned cluster id */ @Experimental - class Assignment(val id: Long, val cluster: Int) extends Serializable + case class Assignment(id: Long, cluster: Int) /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6315c03a70..6d6fe6fe46 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable +import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkContext import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { @@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext assert(x ~== u1(i.toInt) absTol 1e-14) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val model = PowerIterationClusteringSuite.createModel(sc, 3, 10) + try { + model.save(sc, path) + val sameModel = PowerIterationClusteringModel.load(sc, path) + PowerIterationClusteringSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } +} + +object PowerIterationClusteringSuite extends FunSuite { + def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { + val assignments = sc.parallelize( + (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) + new PowerIterationClusteringModel(k, assignments) + } + + def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = { + assert(a.k === b.k) + + val aAssignments = a.assignments.map(x => (x.id, x.cluster)) + val bAssignments = b.assignments.map(x => (x.id, x.cluster)) + val unequalElements = aAssignments.join(bAssignments).filter { + case (id, (c1, c2)) => c1 != c2 }.count() + assert(unequalElements === 0L) + } } -- cgit v1.2.3