From 445d4d9e13ebaee9eceea6135fe7ee47812d97de Mon Sep 17 00:00:00 2001 From: Zakaria_Hili Date: Fri, 25 Nov 2016 13:19:26 +0000 Subject: [SPARK-18356][ML] Improve MLKmeans Performance ## What changes were proposed in this pull request? Spark Kmeans fit() doesn't cache the RDD which generates a lot of warnings : WARN KMeans: The input data is not directly cached, which may hurt performance if its parent RDDs are also uncached. So, Kmeans should cache the internal rdd before calling the Mllib.Kmeans algo, this helped to improve spark kmeans performance by 14% https://github.com/ZakariaHili/spark/commit/a9cf905cf7dbd50eeb9a8b4f891f2f41ea672472 hhbyyh ## How was this patch tested? Pass Kmeans tests and existing tests Author: Zakaria_Hili Author: HILI Zakaria Closes #15965 from ZakariaHili/zakbranch. --- .../org/apache/spark/ml/clustering/KMeans.scala | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6e124eb6dd..ad4f79a79c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion /** @@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + fit(dataset, handlePersistence) + } + + @Since("2.2.0") + protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } - - val instr = Instrumentation.create(this, rdd) + if (handlePersistence) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + val instr = Instrumentation.create(this, instances) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) - val parentModel = algo.run(rdd, Option(instr)) + val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) instr.logSuccess(model) + if (handlePersistence) { + instances.unpersist() + } model } -- cgit v1.2.3