diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 12a76dbbfb..3ac6c77669 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ /** @@ -103,7 +103,8 @@ class RFormula(override val uid: String) RFormulaParser.parse($(formula)).hasIntercept } - override def fit(dataset: DataFrame): RFormulaModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): RFormulaModel = { require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -204,7 +205,8 @@ class RFormulaModel private[feature]( private[ml] val pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase with MLWritable { - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -232,10 +234,10 @@ class RFormulaModel private[feature]( override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" - private def transformLabel(dataset: DataFrame): DataFrame = { + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { - dataset + dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => @@ -246,7 +248,7 @@ class RFormulaModel private[feature]( } else { // Ignore the label field. This is a hack so that this transformer can also work on test // datasets in a Pipeline. - dataset + dataset.toDF } } @@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str def this(columnsToPrune: Set[String]) = this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) dataset.select(columnsToKeep.map(dataset.col): _*) } @@ -396,7 +398,7 @@ private class VectorAttributeRewriter( def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => |