diff options
author | Yuhao Yang <hhbyyh@gmail.com> | 2016-03-31 11:12:40 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-31 11:12:40 -0700 |
commit | a0a1991580ed24230f88cae9f5a4dfbe58f03b28 (patch) | |
tree | d0d230f6116b8b6cd1bff0da73b1a782edb04f68 /mllib/src | |
parent | 3b3cc76004438a942ecea752db39f3a904a52462 (diff) | |
download | spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.gz spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.bz2 spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.zip |
[SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans
## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-13782
Model export/import for BisectingKMeans in spark.ml and mllib
## How was this patch tested?
unit tests
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #11933 from hhbyyh/bisectingsave.
Diffstat (limited to 'mllib/src')
5 files changed, 190 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f014a1d572..55f751c57f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering. {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params /** @group expertParam */ @Since("2.0.0") - final val minDivisibleClusterSize = new Param[Double]( + final val minDivisibleClusterSize = new DoubleParam( this, "minDivisibleClusterSize", "the minimum number of points (if >= 1.0) or the minimum proportion", @@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params class BisectingKMeansModel private[ml] ( @Since("2.0.0") override val uid: String, private val parentModel: MLlibBisectingKMeansModel - ) extends Model[BisectingKMeansModel] with BisectingKMeansParams { + ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { @@ -115,6 +117,44 @@ class BisectingKMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("2.0.0") + override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) +} + +object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { + @Since("2.0.0") + override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader + + @Since("2.0.0") + override def load(path: String): BisectingKMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[BisectingKMeansModel]] */ + private[BisectingKMeansModel] + class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.parentModel.save(sc, dataPath) + } + } + + private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[BisectingKMeansModel].getName + + override def load(path: String): BisectingKMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) + val model = new BisectingKMeansModel(metadata.uid, mllibModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] ( @Experimental class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override val uid: String) - extends Estimator[BisectingKMeansModel] with BisectingKMeansParams { + extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable { setDefault( k -> 4, @@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") ( override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra) @Since("2.0.0") - def this() = this(Identifiable.randomUID("bisecting k-means")) + def this() = this(Identifiable.randomUID("bisecting-kmeans")) /** @group setParam */ @Since("2.0.0") @@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") ( } } + +@Since("2.0.0") +object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { + + @Since("2.0.0") + override def load(path: String): BisectingKMeans = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 64b838a1db..e4bd0dc25e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable { private[clustering] class ClusteringTreeNode private[clustering] ( val index: Int, val size: Long, - private val centerWithNorm: VectorWithNorm, + private[clustering] val centerWithNorm: VectorWithNorm, val cost: Double, val height: Double, val children: Array[ClusteringTreeNode]) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 01a0d31f14..c3b5b8b790 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -17,11 +17,19 @@ package org.apache.spark.mllib.clustering +import org.json4s._ +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} /** * Clustering model produced by [[BisectingKMeans]]. @@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD @Experimental class BisectingKMeansModel private[clustering] ( private[clustering] val root: ClusteringTreeNode - ) extends Serializable with Logging { + ) extends Serializable with Saveable with Logging { /** * Leaf cluster centers. @@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) + + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("2.0.0") +object BisectingKMeansModel extends Loader[BisectingKMeansModel] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): BisectingKMeansModel = { + val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val rootId = (metadata \ "rootId").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path, rootId) + model + case _ => throw new Exception( + s"BisectingKMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $formatVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double, + height: Double, children: Seq[Int]) + + private object Data { + def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3), + r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) + } + + private[clustering] object SaveLoadV1_0 { + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + val dataRDD = sc.parallelize(data).toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes(_)) ++ Array(node) + } + } + + def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + val sqlContext = SQLContext.getOrCreate(sc) + val rows = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode) + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes.get(rootId).get + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b719a8c7e7..18f2c994b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame -class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class BisectingKMeansSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) } + + test("read/write") { + def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val bisectingKMeans = new BisectingKMeans() + testEstimatorAndModelReadWrite( + bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + } +} + +object BisectingKMeansSuite { + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "seed" -> -1L, + "minDivisibleClusterSize" -> 2.0 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala index 41b9d5c0d9..35f7932ae8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("BisectingKMeans model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val points = (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val model = new BisectingKMeans().run(data) + try { + model.save(sc, path) + val sameModel = BisectingKMeansModel.load(sc, path) + assert(model.k === sameModel.k) + model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) + } finally { + Utils.deleteRecursively(tempDir) + } + } } |