aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-03-31 11:12:40 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-31 11:12:40 -0700
commita0a1991580ed24230f88cae9f5a4dfbe58f03b28 (patch)
treed0d230f6116b8b6cd1bff0da73b1a782edb04f68
parent3b3cc76004438a942ecea752db39f3a904a52462 (diff)
downloadspark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.gz
spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.bz2
spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.zip
[SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-13782 Model export/import for BisectingKMeans in spark.ml and mllib ## How was this patch tested? unit tests Author: Yuhao Yang <hhbyyh@gmail.com> Closes #11933 from hhbyyh/bisectingsave.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala59
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala98
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala22
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala18
5 files changed, 190 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f014a1d572..55f751c57f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -17,11 +17,13 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params
/** @group expertParam */
@Since("2.0.0")
- final val minDivisibleClusterSize = new Param[Double](
+ final val minDivisibleClusterSize = new DoubleParam(
this,
"minDivisibleClusterSize",
"the minimum number of points (if >= 1.0) or the minimum proportion",
@@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params
class BisectingKMeansModel private[ml] (
@Since("2.0.0") override val uid: String,
private val parentModel: MLlibBisectingKMeansModel
- ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+ ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -115,6 +117,44 @@ class BisectingKMeansModel private[ml] (
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+}
+
+object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
+ @Since("2.0.0")
+ override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeansModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[BisectingKMeansModel]] */
+ private[BisectingKMeansModel]
+ class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val dataPath = new Path(path, "data").toString
+ instance.parentModel.save(sc, dataPath)
+ }
+ }
+
+ private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[BisectingKMeansModel].getName
+
+ override def load(path: String): BisectingKMeansModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
+ val model = new BisectingKMeansModel(metadata.uid, mllibModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
@@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] (
@Experimental
class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0") override val uid: String)
- extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+ extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable {
setDefault(
k -> 4,
@@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") (
override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
@Since("2.0.0")
- def this() = this(Identifiable.randomUID("bisecting k-means"))
+ def this() = this(Identifiable.randomUID("bisecting-kmeans"))
/** @group setParam */
@Since("2.0.0")
@@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") (
}
}
+
+@Since("2.0.0")
+object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeans = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 64b838a1db..e4bd0dc25e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable {
private[clustering] class ClusteringTreeNode private[clustering] (
val index: Int,
val size: Long,
- private val centerWithNorm: VectorWithNorm,
+ private[clustering] val centerWithNorm: VectorWithNorm,
val cost: Double,
val height: Double,
val children: Array[ClusteringTreeNode]) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 01a0d31f14..c3b5b8b790 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -17,11 +17,19 @@
package org.apache.spark.mllib.clustering
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
/**
* Clustering model produced by [[BisectingKMeans]].
@@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD
@Experimental
class BisectingKMeansModel private[clustering] (
private[clustering] val root: ClusteringTreeNode
- ) extends Serializable with Logging {
+ ) extends Serializable with Saveable with Logging {
/**
* Leaf cluster centers.
@@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+ val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val rootId = (metadata \ "rootId").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, formatVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path, rootId)
+ model
+ case _ => throw new Exception(
+ s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $formatVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double,
+ height: Double, children: Seq[Int])
+
+ private object Data {
+ def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3),
+ r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
+ }
+
+ private[clustering] object SaveLoadV1_0 {
+ private val thisFormatVersion = "1.0"
+
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+ def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+ ~ ("rootId" -> model.root.index)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val data = getNodes(model.root).map(node => Data(node.index, node.size,
+ node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+ node.children.map(_.index)))
+ val dataRDD = sc.parallelize(data).toDF()
+ dataRDD.write.parquet(Loader.dataPath(path))
+ }
+
+ private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+ if (node.children.isEmpty) {
+ Array(node)
+ } else {
+ node.children.flatMap(getNodes(_)) ++ Array(node)
+ }
+ }
+
+ def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val rows = sqlContext.read.parquet(Loader.dataPath(path))
+ Loader.checkSchema[Data](rows.schema)
+ val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+ val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+ val rootNode = buildTree(rootId, nodes)
+ new BisectingKMeansModel(rootNode)
+ }
+
+ private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+ val root = nodes.get(rootId).get
+ if (root.children.isEmpty) {
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, new Array[ClusteringTreeNode](0))
+ } else {
+ val children = root.children.map(c => buildTree(c, nodes))
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, children.toArray)
+ }
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c7e7..18f2c994b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,10 +18,12 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
@transient var dataset: DataFrame = _
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
+
+ test("read/write") {
+ def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = {
+ assert(model.clusterCenters === model2.clusterCenters)
+ }
+ val bisectingKMeans = new BisectingKMeans()
+ testEstimatorAndModelReadWrite(
+ bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+ }
+}
+
+object BisectingKMeansSuite {
+ val allParamSettings: Map[String, Any] = Map(
+ "k" -> 3,
+ "maxIter" -> 2,
+ "seed" -> -1L,
+ "minDivisibleClusterSize" -> 2.0
+ )
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index 41b9d5c0d9..35f7932ae8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("BisectingKMeans model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val points = (1 until 8).map(i => Vectors.dense(i))
+ val data = sc.parallelize(points, 2)
+ val model = new BisectingKMeans().run(data)
+ try {
+ model.save(sc, path)
+ val sameModel = BisectingKMeansModel.load(sc, path)
+ assert(model.k === sameModel.k)
+ model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}