diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 6e124eb6dd..ad4f79a79c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils.majorVersion /** @@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + fit(dataset, handlePersistence) + } + + @Since("2.2.0") + protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) - val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } - - val instr = Instrumentation.create(this, rdd) + if (handlePersistence) { + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + val instr = Instrumentation.create(this, instances) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") ( .setMaxIterations($(maxIter)) .setSeed($(seed)) .setEpsilon($(tol)) - val parentModel = algo.run(rdd, Option(instr)) + val parentModel = algo.run(instances, Option(instr)) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) instr.logSuccess(model) + if (handlePersistence) { + instances.unpersist() + } model } |