aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala18
1 files changed, 10 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 12a76dbbfb..3ac6c77669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
/**
@@ -103,7 +103,8 @@ class RFormula(override val uid: String)
RFormulaParser.parse($(formula)).hasIntercept
}
- override def fit(dataset: DataFrame): RFormulaModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): RFormulaModel = {
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
@@ -204,7 +205,8 @@ class RFormulaModel private[feature](
private[ml] val pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
checkCanTransform(dataset.schema)
transformLabel(pipelineModel.transform(dataset))
}
@@ -232,10 +234,10 @@ class RFormulaModel private[feature](
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
- private def transformLabel(dataset: DataFrame): DataFrame = {
+ private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
- dataset
+ dataset.toDF
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
case _: NumericType | BooleanType =>
@@ -246,7 +248,7 @@ class RFormulaModel private[feature](
} else {
// Ignore the label field. This is a hack so that this transformer can also work on test
// datasets in a Pipeline.
- dataset
+ dataset.toDF
}
}
@@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
def this(columnsToPrune: Set[String]) =
this(Identifiable.randomUID("columnPruner"), columnsToPrune)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
dataset.select(columnsToKeep.map(dataset.col): _*)
}
@@ -396,7 +398,7 @@ private class VectorAttributeRewriter(
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val metadata = {
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
val attrs = group.attributes.get.map { attr =>