diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 55f751c57f..6cc9117da3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.
   {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
@@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] (
    * centers.
    */
   @Since("2.0.0")
-  def computeCost(dataset: DataFrame): Double = {
+  def computeCost(dataset: Dataset[_]): Double = {
     SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
     val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
@@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
   def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
 
   @Since("2.0.0")
-  override def fit(dataset: DataFrame): BisectingKMeansModel = {
+  override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
     val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
 
     val bkm = new MLlibBisectingKMeans()