-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala     | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala |  2
2 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149699..df2a9c0dd5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }
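
The KMeansModel hunks above all serve one change: the VectorWithNorm wrappers for the cluster centers used to be rebuilt by a private def on every predict and computeCost call, and are now computed once in a private val and broadcast directly. A minimal sketch of the same pattern, assuming nothing beyond Spark core; SimpleModel, CenterWithNorm, and the distance code here are illustrative stand-ins, not the MLlib internals (VectorWithNorm and KMeans.findClosest are private to MLlib):

import org.apache.spark.rdd.RDD

// Illustrative stand-in for MLlib's private VectorWithNorm: a center paired
// with its precomputed L2 norm, so distance math can reuse the norm.
case class CenterWithNorm(vector: Array[Double], norm: Double)

class SimpleModel(val centers: Array[Array[Double]]) extends Serializable {
  // Computed once per model instance (the diff's `private val`), instead of
  // being rebuilt by a `def` on every predict/computeCost call.
  private val centersWithNorm =
    centers.map(c => CenterWithNorm(c, math.sqrt(c.map(x => x * x).sum)))

  def predict(points: RDD[Array[Double]]): RDD[Int] = {
    // Broadcast the precomputed array directly; broadcast() captures the
    // value eagerly, so the map closure does not drag the whole model along.
    val bcCenters = points.context.broadcast(centersWithNorm)
    points.map { p =>
      bcCenters.value.zipWithIndex.minBy { case (c, _) =>
        // Plain squared Euclidean distance; MLlib's findClosest additionally
        // uses the cached norms for a cheaper lower bound, omitted here.
        c.vector.zip(p).map { case (a, b) => (a - b) * (a - b) }.sum
      }._2
    }
  }
}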
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c438d..3ca75e8cdb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
}
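
The StreamingKMeans hunk is a one-line behavioral fix: update now returns a fresh StreamingKMeansModel built from the updated centers and weights instead of returning this. A hedged usage sketch; model (an already-trained StreamingKMeansModel) and batch (an RDD[Vector] of new points) are assumed to exist in scope, while update(data, decayFactor, timeUnit) is the actual MLlib signature:

// Sketch only: `model` and `batch` are assumed to be defined elsewhere.
// decayFactor = 1.0 with timeUnit = "batches" weights all history equally.
val updated = model.update(batch, decayFactor = 1.0, timeUnit = "batches")

// With this patch the result is a distinct object rather than the receiver,
// so code that keeps a handle per processed batch gets separate objects.
assert(updated ne model)
println(s"k = ${updated.k}, first center = ${updated.clusterCenters.head}")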