Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Predictor.scala')
 mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ebe48700f8..81140d1f7b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
/**
* Validates and transforms the input schema with the provided param map.
+ *
* @param schema input schema
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
- // TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
@@ -83,7 +83,7 @@ abstract class Predictor[
/** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
- override def fit(dataset: DataFrame): M = {
+ override def fit(dataset: Dataset[_]): M = {
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
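
With fit taking Dataset[_], callers can pass either a DataFrame (which is just Dataset[Row]) or a strongly typed Dataset without converting first. A usage sketch, assuming a SparkSession (or, at the time of this change, SQLContext) named spark and the stock LogisticRegression predictor:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

val training = spark.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0))
)).toDF("label", "features")

// A DataFrame matches fit(dataset: Dataset[_]) directly; a typed Dataset
// would be accepted through the same signature.
val model = new LogisticRegression().fit(training)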
@@ -100,7 +100,7 @@ abstract class Predictor[
* @param dataset Training dataset
* @return Fitted model
*/
- protected def train(dataset: DataFrame): M
+ protected def train(dataset: Dataset[_]): M
/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
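
Since train is the one method Predictor subclasses implement, its new Dataset[_] signature is the main change third-party predictors must pick up. A hypothetical minimal Predictor/PredictionModel pair under the new contract (the "learning" is a placeholder that predicts the mean label):

import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.Dataset

class MeanRegressor(override val uid: String)
  extends Predictor[Vector, MeanRegressor, MeanRegressionModel] {

  override protected def train(dataset: Dataset[_]): MeanRegressionModel = {
    // extractLabeledPoints handles column selection and the label cast.
    val meanLabel = extractLabeledPoints(dataset).map(_.label).mean()
    new MeanRegressionModel(uid, meanLabel)
  }

  override def copy(extra: ParamMap): MeanRegressor = defaultCopy(extra)
}

class MeanRegressionModel(override val uid: String, val meanLabel: Double)
  extends PredictionModel[Vector, MeanRegressionModel] {

  // Constant prediction, just enough to exercise the PredictionModel plumbing.
  override protected def predict(features: Vector): Double = meanLabel

  override def copy(extra: ParamMap): MeanRegressionModel =
    copyValues(new MeanRegressionModel(uid, meanLabel), extra).setParent(parent)
}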
@@ -120,10 +120,9 @@ abstract class Predictor[
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
- protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
- dataset.select($(labelCol), $(featuresCol)).rdd.map {
- case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
+ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
}
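
The explicit cast is what makes the relaxed numeric label check above safe: a non-Double label is converted before the Row(label: Double, ...) pattern match, which previously would have thrown a MatchError. An illustrative run on an IntegerType label column (data and column names made up; assumes spark as above):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

val intLabeled = spark.createDataFrame(Seq(
  (1, Vectors.dense(0.5, 1.0)),
  (0, Vectors.dense(1.5, 2.0))
)).toDF("label", "features")

// The new selection casts the label up front, mirroring the code above.
val casted = intLabeled.select(col("label").cast(DoubleType), col("features"))
casted.rdd.map(_.getDouble(0)).collect()  // Array(1.0, 0.0)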
@@ -172,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
* @param dataset input dataset
* @return transformed dataset with [[predictionCol]] of type [[Double]]
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
- dataset
+ dataset.toDF
}
}
- protected def transformImpl(dataset: DataFrame): DataFrame = {
+ protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
}
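
transform still promises a DataFrame even though it now accepts any Dataset[_], which is why the NOOP branch returns dataset.toDF; transformImpl then appends the prediction column by wrapping predict in a UDF. Reusing the hypothetical model and training data from the fit sketch above:

// transform() adds the predictionCol (default "prediction") via the UDF built
// in transformImpl and returns a DataFrame regardless of the input's type.
val predictions = model.transform(training)
predictions.select("features", "prediction").show()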