aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2016-10-13 19:44:24 -0700
committer Yanbo Liang <ybliang8@gmail.com> 2016-10-13 19:44:24 -0700
commit 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5 (patch)
tree 3367aaa95f3c39c0ce2730521f4b7a46752339cd /python/pyspark/ml
parent 9dc0ca060d5925cd666b34021e62f7b38bb3aabb (diff)
download spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.tar.gz
spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.tar.bz2
spark-44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5.zip
[SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel.
## What changes were proposed in this pull request?
Follow-up work of #13675, add Python API for ```RFormula forceIndexLabel```.

## How was this patch tested?
Unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15430 from yanboliang/spark-15957-python.
Diffstat (limited to 'python/pyspark/ml')
-rwxr-xr-x python/pyspark/ml/feature.py 31
-rwxr-xr-x python/pyspark/ml/tests.py 16
2 files changed, 43 insertions, 4 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 64b21caa61..a33c3e7945 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
formula = Param(Params._dummy(), "formula", "R model formula",
typeConverter=TypeConverters.toString)
+ forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
+ "Force to index label whether it is numeric or string",
+ typeConverter=TypeConverters.toBoolean)
+
@keyword_only
- def __init__(self, formula=None, featuresCol="features", labelCol="label"):
+ def __init__(self, formula=None, featuresCol="features", labelCol="label",
+ forceIndexLabel=False):
"""
- __init__(self, formula=None, featuresCol="features", labelCol="label")
+ __init__(self, formula=None, featuresCol="features", labelCol="label", \
+ forceIndexLabel=False)
"""
super(RFormula, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
+ self._setDefault(forceIndexLabel=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.5.0")
- def setParams(self, formula=None, featuresCol="features", labelCol="label"):
+ def setParams(self, formula=None, featuresCol="features", labelCol="label",
+ forceIndexLabel=False):
"""
- setParams(self, formula=None, featuresCol="features", labelCol="label")
+ setParams(self, formula=None, featuresCol="features", labelCol="label", \
+ forceIndexLabel=False)
Sets params for RFormula.
"""
kwargs = self.setParams._input_kwargs
@@ -2528,6 +2537,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
"""
return self.getOrDefault(self.formula)
+ @since("2.1.0")
+ def setForceIndexLabel(self, value):
+ """
+ Sets the value of :py:attr:`forceIndexLabel`.
+ """
+ return self._set(forceIndexLabel=value)
+
+ @since("2.1.0")
+ def getForceIndexLabel(self):
+ """
+ Gets the value of :py:attr:`forceIndexLabel`.
+ """
+ return self.getOrDefault(self.forceIndexLabel)
+
def _create_model(self, java_model):
return RFormulaModel(java_model)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e233549850..9d46cc3b4a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -477,6 +477,22 @@ class FeatureTests(SparkSessionTestCase):
feature, expected = r
self.assertEqual(feature, expected)
+ def test_rformula_force_index_label(self):
+ df = self.spark.createDataFrame([
+ (1.0, 1.0, "a"),
+ (0.0, 2.0, "b"),
+ (1.0, 0.0, "a")], ["y", "x", "s"])
+ # Does not index label by default since it's numeric type.
+ rf = RFormula(formula="y ~ x + s")
+ model = rf.fit(df)
+ transformedDF = model.transform(df)
+ self.assertEqual(transformedDF.head().label, 1.0)
+ # Force to index label.
+ rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
+ model2 = rf2.fit(df)
+ transformedDF2 = model2.transform(df)
+ self.assertEqual(transformedDF2.head().label, 0.0)
+
class HasInducedError(Params):