aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 55f751c57f..6cc9117da3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] (
}
@Since("2.0.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] (
* centers.
*/
@Since("2.0.0")
- def computeCost(dataset: DataFrame): Double = {
+ def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
@@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
@Since("2.0.0")
- override def fit(dataset: DataFrame): BisectingKMeansModel = {
+ override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
val bkm = new MLlibBisectingKMeans()