aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Xiangrui Meng <meng@databricks.com> 2015-08-03 13:59:35 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-08-03 13:59:35 -0700
commit e4765a46833baff1dd7465c4cf50e947de7e8f21 (patch)
tree 56956773833adb17e5a0052713d5e5fc88c8ec2d /mllib
parent 8ca287ebbd58985a568341b08040d0efa9d3641a (diff)
download spark-e4765a46833baff1dd7465c4cf50e947de7e8f21.tar.gz
spark-e4765a46833baff1dd7465c4cf50e947de7e8f21.tar.bz2
spark-e4765a46833baff1dd7465c4cf50e947de7e8f21.zip
[SPARK-9544] [MLLIB] add Python API for RFormula
Add Python API for RFormula, similar to other feature transformers in Python. This is just a thin wrapper over the Scala implementation. ericl MechCoder Author: Xiangrui Meng <meng@databricks.com> Closes #7879 from mengxr/SPARK-9544 and squashes the following commits: 3d5ff03 [Xiangrui Meng] add an doctest for . and - 5e969a5 [Xiangrui Meng] fix pydoc 1cd41f8 [Xiangrui Meng] organize imports 3c18b10 [Xiangrui Meng] add Python API for RFormula
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 21
1 files changed, 7 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index d1726917e4..d5360c9217 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
@@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
*/
val formula: Param[String] = new Param(this, "formula", "R model formula")
- private var parsedFormula: Option[ParsedRFormula] = None
-
/**
* Sets the formula to use for this transformer. Must be called before use.
* @group setParam
* @param value an R formula in string form (e.g. "y ~ x + z")
*/
- def setFormula(value: String): this.type = {
- parsedFormula = Some(RFormulaParser.parse(value))
- set(formula, value)
- this
- }
+ def setFormula(value: String): this.type = set(formula, value)
/** @group getParam */
def getFormula: String = $(formula)
/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- parsedFormula.get.hasIntercept
+ require(isDefined(formula), "Formula must be defined first.")
+ RFormulaParser.parse($(formula)).hasIntercept
}
override def fit(dataset: DataFrame): RFormulaModel = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
+ require(isDefined(formula), "Formula must be defined first.")
+ val parsedFormula = RFormulaParser.parse($(formula))
+ val resolvedFormula = parsedFormula.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]()