diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-10-13 19:44:24 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-10-13 19:44:24 -0700 |
commit | 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5 (patch) | |
tree | 3367aaa95f3c39c0ce2730521f4b7a46752339cd /python/pyspark/ml | |
parent | 9dc0ca060d5925cd666b34021e62f7b38bb3aabb (diff) | |
download | spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.tar.gz spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.tar.bz2 spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.zip |
[SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel.
## What changes were proposed in this pull request?
Follow-up to #13675: add a Python API for the `RFormula` `forceIndexLabel` param.
## How was this patch tested?
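Added a unit test (`test_rformula_force_index_label` in `python/pyspark/ml/tests.py`) that checks the label is left unindexed by default and is indexed when `forceIndexLabel` is set to `True`.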
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #15430 from yanboliang/spark-15957-python.
Diffstat (limited to 'python/pyspark/ml')
-rwxr-xr-x | python/pyspark/ml/feature.py | 31 | ||||
-rwxr-xr-x | python/pyspark/ml/tests.py | 16 |
2 files changed, 43 insertions, 4 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa61..a33c3e7945 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM formula = Param(Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString) + forceIndexLabel = Param(Params._dummy(), "forceIndexLabel", + "Force to index label whether it is numeric or string", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, formula=None, featuresCol="features", labelCol="label"): + def __init__(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - __init__(self, formula=None, featuresCol="features", labelCol="label") + __init__(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self._setDefault(forceIndexLabel=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") - def setParams(self, formula=None, featuresCol="features", labelCol="label"): + def setParams(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - setParams(self, formula=None, featuresCol="features", labelCol="label") + setParams(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) Sets params for RFormula. """ kwargs = self.setParams._input_kwargs @@ -2528,6 +2537,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM """ return self.getOrDefault(self.formula) + @since("2.1.0") + def setForceIndexLabel(self, value): + """ + Sets the value of :py:attr:`forceIndexLabel`. 
+ """ + return self._set(forceIndexLabel=value) + + @since("2.1.0") + def getForceIndexLabel(self): + """ + Gets the value of :py:attr:`forceIndexLabel`. + """ + return self.getOrDefault(self.forceIndexLabel) + def _create_model(self, java_model): return RFormulaModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e233549850..9d46cc3b4a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -477,6 +477,22 @@ class FeatureTests(SparkSessionTestCase): feature, expected = r self.assertEqual(feature, expected) + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + class HasInducedError(Params): |