author     Xusen Yin <yinxusen@gmail.com>        2015-03-11 00:24:55 -0700
committer  Xiangrui Meng <meng@databricks.com>   2015-03-11 00:24:55 -0700
commit     2d4e00efe2cf179935ae108a68f28edf6e5a1628 (patch)
tree       2980f1d4c8467c4204ddc519250bd007c8ebcdd7
parent     2672374110d58e45ffae2408e74b96613deddda3 (diff)
[SPARK-5986][MLLib] Add save/load for k-means
This PR adds save/load for K-means as described in SPARK-5986. The Python version will be added in another PR.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #4951 from yinxusen/SPARK-5986 and squashes the following commits:

6dd74a0 [Xusen Yin] rewrite some functions and classes
cd390fd [Xusen Yin] add indexed point
b144216 [Xusen Yin] remove invalid comments
dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala  68
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala  44
2 files changed, 108 insertions(+), 4 deletions(-)
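
For context, a minimal sketch of how the new API is intended to be used, assuming a live SparkContext `sc`; the training data, path, and parameter values below are illustrative, not taken from this patch:

    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    // Train a small model on toy data (two obvious clusters).
    val data = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
      Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)))
    val model = KMeans.train(data, k = 2, maxIterations = 10)

    // Persist the model and read it back with the methods added here.
    model.save(sc, "/tmp/kmeans-model")
    val sameModel = KMeansModel.load(sc, "/tmp/kmeans-model")
    assert(sameModel.k == model.k)
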
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 3b95a9e693..707da537d2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,15 +17,22 @@
package org.apache.spark.mllib.clustering
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.Row
/**
* A clustering model for K-means. Each point belongs to the cluster with the closest center.
*/
-class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
+class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
/** Total number of clusters. */
def k: Int = clusterCenters.length
@@ -58,4 +65,59 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
clusterCenters.map(new VectorWithNorm(_))
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ KMeansModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object KMeansModel extends Loader[KMeansModel] {
+ override def load(sc: SparkContext, path: String): KMeansModel = {
+ KMeansModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private case class Cluster(id: Int, point: Vector)
+
+ private object Cluster {
+ def apply(r: Row): Cluster = {
+ Cluster(r.getInt(0), r.getAs[Vector](1))
+ }
+ }
+
+ private[clustering]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+ def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+ val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
+ Cluster(id, point)
+ }.toDF()
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): KMeansModel = {
+ implicit val formats = DefaultFormats
+ val sqlContext = new SQLContext(sc)
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+ val k = (metadata \ "k").extract[Int]
+ val centroids = sqlContext.parquetFile(Loader.dataPath(path))
+ Loader.checkSchema[Cluster](centroids.schema)
+ val localCentroids = centroids.map(Cluster.apply).collect()
+ assert(k == localCentroids.size)
+ new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
+ }
+ }
}
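
For reference, SaveLoadV1_0 implies the following on-disk layout: save() writes a single JSON line of metadata via saveAsTextFile, and the cluster centers as Parquet rows of (id, point). A sketch of a saved model directory, assuming the conventional Loader.metadataPath/dataPath locations; the path and k value are illustrative:

    /tmp/kmeans-model/
      metadata/part-00000   -- one JSON line, e.g.:
        {"class":"org.apache.spark.mllib.clustering.KMeansModel","version":"1.0","k":3}
      data/                 -- Parquet files with schema (id: Int, point: Vector)
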
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index caee591700..7bf250eb5a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,9 +21,10 @@ import scala.util.Random
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class KMeansSuite extends FunSuite with MLlibTestSparkContext {
@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(predicts(0) != predicts(3))
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(true, false).foreach { selector =>
+ val model = KMeansSuite.createModel(10, 3, selector)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = KMeansModel.load(sc, path)
+ KMeansSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}
+
+object KMeansSuite extends FunSuite {
+ def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
+ val singlePoint = if (isSparse) {
+ Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
+ } else {
+ Vectors.dense(Array.fill[Double](dim)(0.0))
+ }
+ new KMeansModel(Array.fill[Vector](k)(singlePoint))
+ }
+
+ def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
+ assert(a.k === b.k)
+ a.clusterCenters.zip(b.clusterCenters).foreach {
+ case (ca: SparseVector, cb: SparseVector) =>
+ assert(ca === cb)
+ case (ca: DenseVector, cb: DenseVector) =>
+ assert(ca === cb)
+ case _ =>
+ throw new AssertionError("checkEqual failed since the two cluster centers were not identical.\n")
+ }
+ }
}
class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {