diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 865614aa5c..d00fee12b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** |