about summary refs log tree commit diff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-08-03 13:59:35 -0700
committerXiangrui Meng <meng@databricks.com>2015-08-03 13:59:35 -0700
commite4765a46833baff1dd7465c4cf50e947de7e8f21 (patch)
tree56956773833adb17e5a0052713d5e5fc88c8ec2d
parent8ca287ebbd58985a568341b08040d0efa9d3641a (diff)
downloadspark-e4765a46833baff1dd7465c4cf50e947de7e8f21.tar.gz
spark-e4765a46833baff1dd7465c4cf50e947de7e8f21.tar.bz2
spark-e4765a46833baff1dd7465c4cf50e947de7e8f21.zip
[SPARK-9544] [MLLIB] add Python API for RFormula
Add Python API for RFormula. Similar to other feature transformers in Python. This is just a thin wrapper over the Scala implementation. ericl MechCoder Author: Xiangrui Meng <meng@databricks.com> Closes #7879 from mengxr/SPARK-9544 and squashes the following commits: 3d5ff03 [Xiangrui Meng] add an doctest for . and - 5e969a5 [Xiangrui Meng] fix pydoc 1cd41f8 [Xiangrui Meng] organize imports 3c18b10 [Xiangrui Meng] add Python API for RFormula
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala21
-rw-r--r--python/pyspark/ml/feature.py85
2 files changed, 91 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index d1726917e4..d5360c9217 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
@@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
*/
val formula: Param[String] = new Param(this, "formula", "R model formula")
- private var parsedFormula: Option[ParsedRFormula] = None
-
/**
* Sets the formula to use for this transformer. Must be called before use.
* @group setParam
* @param value an R formula in string form (e.g. "y ~ x + z")
*/
- def setFormula(value: String): this.type = {
- parsedFormula = Some(RFormulaParser.parse(value))
- set(formula, value)
- this
- }
+ def setFormula(value: String): this.type = set(formula, value)
/** @group getParam */
def getFormula: String = $(formula)
/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- parsedFormula.get.hasIntercept
+ require(isDefined(formula), "Formula must be defined first.")
+ RFormulaParser.parse($(formula)).hasIntercept
}
override def fit(dataset: DataFrame): RFormulaModel = {
- require(parsedFormula.isDefined, "Must call setFormula() first.")
- val resolvedFormula = parsedFormula.get.resolve(dataset.schema)
+ require(isDefined(formula), "Formula must be defined first.")
+ val parsedFormula = RFormulaParser.parse($(formula))
+ val resolvedFormula = parsedFormula.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]()
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 015e7a9d49..3f04c41ac5 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc
__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
- 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel']
+ 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
@inherit_doc
@@ -1110,6 +1110,89 @@ class PCAModel(JavaModel):
"""
+@inherit_doc
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
+ """
+ .. note:: Experimental
+
+ Implements the transforms required for fitting a dataset against an
+ R model formula. Currently we support a limited subset of the R
+ operators, including '~', '+', '-', and '.'. Also see the R formula
+ docs:
+ http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, 1.0, "a"),
+ ... (0.0, 2.0, "b"),
+ ... (0.0, 0.0, "a")
+ ... ], ["y", "x", "s"])
+ >>> rf = RFormula(formula="y ~ x + s")
+ >>> rf.fit(df).transform(df).show()
+ +---+---+---+---------+-----+
+ | y| x| s| features|label|
+ +---+---+---+---------+-----+
+ |1.0|1.0| a|[1.0,1.0]| 1.0|
+ |0.0|2.0| b|[2.0,0.0]| 0.0|
+ |0.0|0.0| a|[0.0,1.0]| 0.0|
+ +---+---+---+---------+-----+
+ ...
+ >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
+ +---+---+---+--------+-----+
+ | y| x| s|features|label|
+ +---+---+---+--------+-----+
+ |1.0|1.0| a| [1.0]| 1.0|
+ |0.0|2.0| b| [2.0]| 0.0|
+ |0.0|0.0| a| [0.0]| 0.0|
+ +---+---+---+--------+-----+
+ ...
+ """
+
+ # a placeholder to make it appear in the generated doc
+ formula = Param(Params._dummy(), "formula", "R model formula")
+
+ @keyword_only
+ def __init__(self, formula=None, featuresCol="features", labelCol="label"):
+ """
+ __init__(self, formula=None, featuresCol="features", labelCol="label")
+ """
+ super(RFormula, self).__init__()
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
+ self.formula = Param(self, "formula", "R model formula")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, formula=None, featuresCol="features", labelCol="label"):
+ """
+ setParams(self, formula=None, featuresCol="features", labelCol="label")
+ Sets params for RFormula.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def setFormula(self, value):
+ """
+ Sets the value of :py:attr:`formula`.
+ """
+ self._paramMap[self.formula] = value
+ return self
+
+ def getFormula(self):
+ """
+ Gets the value of :py:attr:`formula`.
+ """
+ return self.getOrDefault(self.formula)
+
+ def _create_model(self, java_model):
+ return RFormulaModel(java_model)
+
+
+class RFormulaModel(JavaModel):
+ """
+ Model fitted by :py:class:`RFormula`.
+ """
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext