aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 20
1 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e124eb6dd..ad4f79a79c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
/**
@@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
+ // Ask the Dataset itself for its storage level. `dataset.rdd` derives a new RDD that is
+ // never marked as persisted, so checking that RDD would report StorageLevel.NONE even for
+ // a cached Dataset and the input would be persisted a second time on every call.
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+ // Delegate to the overload that persists/unpersists the extracted feature RDD on demand.
+ fit(dataset, handlePersistence)
+ }
+
+ @Since("2.2.0")
+ protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = {
transformSchema(dataset.schema, logging = true)
- val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+ val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
-
- val instr = Instrumentation.create(this, rdd)
+ if (handlePersistence) {
+ instances.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+ val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
@@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") (
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
- val parentModel = algo.run(rdd, Option(instr))
+ val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
instr.logSuccess(model)
+ if (handlePersistence) {
+ instances.unpersist()
+ }
model
}