diff options
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 21 |
1 files changed, 7 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d1726917e4..d5360c9217 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -19,16 +19,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** @@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R */ val formula: Param[String] = new Param(this, "formula", "R model formula") - private var parsedFormula: Option[ParsedRFormula] = None - /** * Sets the formula to use for this transformer. Must be called before use. * @group setParam * @param value an R formula in string form (e.g. "y ~ x + z") */ - def setFormula(value: String): this.type = { - parsedFormula = Some(RFormulaParser.parse(value)) - set(formula, value) - this - } + def setFormula(value: String): this.type = set(formula, value) /** @group getParam */ def getFormula: String = $(formula) /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - parsedFormula.get.hasIntercept + require(isDefined(formula), "Formula must be defined first.") + RFormulaParser.parse($(formula)).hasIntercept } override def fit(dataset: DataFrame): RFormulaModel = { - require(parsedFormula.isDefined, "Must call setFormula() first.") - val resolvedFormula = parsedFormula.get.resolve(dataset.schema) + require(isDefined(formula), "Formula must be defined first.") + val parsedFormula = RFormulaParser.parse($(formula)) + val resolvedFormula = parsedFormula.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() |