aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZakaria_Hili <zakahili@gmail.com>2016-11-25 13:19:26 +0000
committerSean Owen <sowen@cloudera.com>2016-11-25 13:19:26 +0000
commit445d4d9e13ebaee9eceea6135fe7ee47812d97de (patch)
tree6daa2073a398565a33c59274108c40522123bc96 /mllib
parent5ecdc7c5c019acc6b1f9c2e6c5b7d35957eadb88 (diff)
downloadspark-445d4d9e13ebaee9eceea6135fe7ee47812d97de.tar.gz
spark-445d4d9e13ebaee9eceea6135fe7ee47812d97de.tar.bz2
spark-445d4d9e13ebaee9eceea6135fe7ee47812d97de.zip
[SPARK-18356][ML] Improve MLKmeans Performance
## What changes were proposed in this pull request? Spark Kmeans fit() doesn't cache the RDD which generates a lot of warnings : WARN KMeans: The input data is not directly cached, which may hurt performance if its parent RDDs are also uncached. So, Kmeans should cache the internal rdd before calling the Mllib.Kmeans algo, this helped to improve spark kmeans performance by 14% https://github.com/ZakariaHili/spark/commit/a9cf905cf7dbd50eeb9a8b4f891f2f41ea672472 hhbyyh ## How was this patch tested? Pass Kmeans tests and existing tests Author: Zakaria_Hili <zakahili@gmail.com> Author: HILI Zakaria <zakahili@gmail.com> Closes #15965 from ZakariaHili/zakbranch.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala20
1 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e124eb6dd..ad4f79a79c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
/**
@@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ fit(dataset, handlePersistence)
+ }
+
+ @Since("2.2.0")
+ protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = {
transformSchema(dataset.schema, logging = true)
- val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+ val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
-
- val instr = Instrumentation.create(this, rdd)
+ if (handlePersistence) {
+ instances.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+ val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
@@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") (
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
- val parentModel = algo.run(rdd, Option(instr))
+ val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
instr.logSuccess(model)
+ if (handlePersistence) {
+ instances.unpersist()
+ }
model
}