aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 790ef1fe8d..6f63d04818 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -211,7 +211,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
Data(idx, center)
}
val dataPath = new Path(path, "data").toString
- sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
+ sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}
@@ -222,8 +222,8 @@ object KMeansModel extends MLReadable[KMeansModel] {
override def load(path: String): KMeansModel = {
// Import implicits for Dataset Encoder
- val sqlContext = super.sqlContext
- import sqlContext.implicits._
+ val sparkSession = super.sparkSession
+ import sparkSession.implicits._
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
@@ -232,11 +232,11 @@ object KMeansModel extends MLReadable[KMeansModel] {
val versionRegex(major, _) = metadata.sparkVersion
val clusterCenters = if (major.toInt >= 2) {
- val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+ val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
// Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
- sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+ sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
}
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)