From a0a1991580ed24230f88cae9f5a4dfbe58f03b28 Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Thu, 31 Mar 2016 11:12:40 -0700
Subject: [SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-13782
Model export/import for BisectingKMeans in spark.ml and mllib

## How was this patch tested?

unit tests

Author: Yuhao Yang

Closes #11933 from hhbyyh/bisectingsave.
---
 .../spark/ml/clustering/BisectingKMeans.scala      | 59 +++++++++++--
 .../spark/mllib/clustering/BisectingKMeans.scala   |  2 +-
 .../mllib/clustering/BisectingKMeansModel.scala    | 98 +++++++++++++++++++++-
 3 files changed, 151 insertions(+), 8 deletions(-)

(limited to 'mllib/src/main/scala')

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f014a1d572..55f751c57f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.
   {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params
 
   /** @group expertParam */
   @Since("2.0.0")
-  final val minDivisibleClusterSize = new Param[Double](
+  final val minDivisibleClusterSize = new DoubleParam(
     this,
     "minDivisibleClusterSize",
     "the minimum number of points (if >= 1.0) or the minimum proportion",
@@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params
 class BisectingKMeansModel private[ml] (
     @Since("2.0.0") override val uid: String,
     private val parentModel: MLlibBisectingKMeansModel
-  ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+  ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
 
   @Since("2.0.0")
   override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -115,6 +117,44 @@ class BisectingKMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+}
+
+object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
+  @Since("2.0.0")
+  override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): BisectingKMeansModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[BisectingKMeansModel]] */
+  private[BisectingKMeansModel]
+  class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val dataPath = new Path(path, "data").toString
+      instance.parentModel.save(sc, dataPath)
+    }
+  }
+
+  private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[BisectingKMeansModel].getName
+
+    override def load(path: String): BisectingKMeansModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
+      val model = new BisectingKMeansModel(metadata.uid, mllibModel)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
@@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] (
 @Experimental
 class BisectingKMeans @Since("2.0.0") (
     @Since("2.0.0") override val uid: String)
-  extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+  extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable {
 
   setDefault(
     k -> 4,
@@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") (
   override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
 
   @Since("2.0.0")
-  def this() = this(Identifiable.randomUID("bisecting k-means"))
+  def this() = this(Identifiable.randomUID("bisecting-kmeans"))
 
   /** @group setParam */
   @Since("2.0.0")
@@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") (
   }
 }
 
+
+@Since("2.0.0")
+object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
+
+  @Since("2.0.0")
+  override def load(path: String): BisectingKMeans = super.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 64b838a1db..e4bd0dc25e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable {
 private[clustering] class ClusteringTreeNode private[clustering] (
     val index: Int,
     val size: Long,
-    private val centerWithNorm: VectorWithNorm,
+    private[clustering] val centerWithNorm: VectorWithNorm,
     val cost: Double,
     val height: Double,
     val children: Array[ClusteringTreeNode]) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 01a0d31f14..c3b5b8b790 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -17,11 +17,19 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * Clustering model produced by [[BisectingKMeans]].
@@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class BisectingKMeansModel private[clustering] (
     private[clustering] val root: ClusteringTreeNode
-  ) extends Serializable with Logging {
+  ) extends Serializable with Saveable with Logging {
 
   /**
    * Leaf cluster centers.
@@ -92,4 +100,107 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+
+  /**
+   * Saves this model under `path` by delegating to
+   * `BisectingKMeansModel.SaveLoadV1_0.save`.
+   */
+  @Since("2.0.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
+
+  /**
+   * Loads a model from `path`, accepting only metadata whose class name and format version
+   * both match the 1.0 format written by `SaveLoadV1_0`.
+   */
+  @Since("2.0.0")
+  override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+    val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, formatVersion) match {
+      // The backticks turn the name into a stable-identifier pattern, so the loaded class name
+      // is compared with the value above. A bare lowercase name would bind a new variable and
+      // accept every class name.
+      case (`classNameV1_0`, "1.0") =>
+        // rootId is read only for a recognized format; unknown formats fall through to the
+        // descriptive error below instead of failing in JSON extraction.
+        implicit val formats = DefaultFormats
+        val rootId = (metadata \ "rootId").extract[Int]
+        SaveLoadV1_0.load(sc, path, rootId)
+      case _ => throw new Exception(
+        s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
+          s"($loadedClassName, $formatVersion). Supported:\n" +
+          s" ($classNameV1_0, 1.0)")
+    }
+  }
+
+  /** Flat, per-node record; `children` holds the indices of the child nodes. */
+  private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double,
+      height: Double, children: Seq[Int])
+
+  private object Data {
+    /** Builds a record from a row whose columns are in the case-class field order. */
+    def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3),
+      r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
+  }
+
+  private[clustering] object SaveLoadV1_0 {
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+    /** Writes JSON metadata (class, version, rootId) and one Parquet row per tree node. */
+    def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+      val sqlContext = SQLContext.getOrCreate(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+          ~ ("rootId" -> model.root.index)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val data = getNodes(model.root).map(node => Data(node.index, node.size,
+        node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+        node.children.map(_.index)))
+      val dataRDD = sc.parallelize(data).toDF()
+      dataRDD.write.parquet(Loader.dataPath(path))
+    }
+
+    /** Collects every node of the subtree rooted at `node`, children before their parent. */
+    private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+      if (node.children.isEmpty) {
+        Array(node)
+      } else {
+        node.children.flatMap(getNodes(_)) ++ Array(node)
+      }
+    }
+
+    /** Reads the node records back and rebuilds the tree starting at `rootId`. */
+    def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+      val sqlContext = SQLContext.getOrCreate(sc)
+      val rows = sqlContext.read.parquet(Loader.dataPath(path))
+      Loader.checkSchema[Data](rows.schema)
+      val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+      val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+      val rootNode = buildTree(rootId, nodes)
+      new BisectingKMeansModel(rootNode)
+    }
+
+    /** Recursively rebuilds the node with index `rootId` and its descendants. */
+    private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+      // Map.apply reports the missing key in its exception, unlike Option.get.
+      val root = nodes(rootId)
+      // A leaf has no child ids, so this yields an empty array without a separate branch.
+      val children = root.children.map(c => buildTree(c, nodes)).toArray
+      new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+        root.cost, root.height, children)
+    }
+  }
 }
-- 
cgit v1.2.3