aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorSean Owen <sowen@cloudera.com>2016-12-30 10:40:17 +0000
committerSean Owen <sowen@cloudera.com>2016-12-30 10:40:17 +0000
commit56d3a7eb83f9c91d06dab2c91e10569723eeb105 (patch)
tree6f6ff01f72b0feb70c654fb52289ddfe06effc03 /mllib/src
parent63036aee2271cdbb7032b51b2ac67edbcb82389e (diff)
downloadspark-56d3a7eb83f9c91d06dab2c91e10569723eeb105.tar.gz
spark-56d3a7eb83f9c91d06dab2c91e10569723eeb105.tar.bz2
spark-56d3a7eb83f9c91d06dab2c91e10569723eeb105.zip
[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient
## What changes were proposed in this pull request? mllib.KMeansModel.clusterCentersWithNorm is a method than ends up being called every time `predict` is called on a single vector, which is bad news for now the ml.KMeansModel Transformer works, which necessarily transforms one vector at a time. This causes the model to just store the vectors with norms upfront. The extra norm should be small compared to the vectors. This would avoid this form of overhead on this and other code paths. ## How was this patch tested? Existing tests. Author: Sean Owen <sowen@cloudera.com> Closes #16328 from srowen/SPARK-18808.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala2
2 files changed, 9 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149699..df2a9c0dd5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
extends Saveable with Serializable with PMMLExportable {
+ private val clusterCentersWithNorm =
+ if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
/**
* A Java-friendly constructor that takes an Iterable of Vectors.
*/
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
* Total number of clusters.
*/
@Since("0.8.0")
- def k: Int = clusterCenters.length
+ def k: Int = clusterCentersWithNorm.length
/**
* Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
*/
@Since("1.0.0")
def predict(points: RDD[Vector]): RDD[Int] = {
- val centersWithNorm = clusterCentersWithNorm
- val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+ val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
}
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
*/
@Since("0.8.0")
def computeCost(data: RDD[Vector]): Double = {
- val centersWithNorm = clusterCentersWithNorm
- val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+ val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
}
- private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
- clusterCenters.map(new VectorWithNorm(_))
@Since("1.4.0")
override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
- val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
- Cluster(id, point)
+ val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+ Cluster(id, p.vector)
}
spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c438d..3ca75e8cdb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
}
}
- this
+ new StreamingKMeansModel(clusterCenters, clusterWeights)
}
}