author     Sandeep Singh <sandeep@techaddict.me>    2016-11-30 11:33:15 +0200
committer  Nick Pentreath <nickp@za.ibm.com>        2016-11-30 11:33:15 +0200
commit     fe854f2e4fb2fa1a1c501f11030e36f489ca546f (patch)
tree       9b6179709687e5e10cf742eb61b82189dfce76de /python/pyspark/ml/feature.py
parent     56c82edabd62db9e936bb9afcf300faf8ef39362 (diff)
[SPARK-18366][PYSPARK][ML] Add handleInvalid to Pyspark for QuantileDiscretizer and Bucketizer
## What changes were proposed in this pull request?

Added the new handleInvalid param for these transformers to Python to maintain API parity with Scala.

## How was this patch tested?

Existing tests; the new behavior is covered by new doctests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #15817 from techaddict/SPARK-18366.
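A minimal sketch of the new parameter in use (illustration only, not part of the patch), assuming a live `SparkSession` bound to `spark` as in the doctests below:

```python
from pyspark.ml.feature import Bucketizer

df = spark.createDataFrame(
    [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),)], ["values"])

bucketizer = Bucketizer(
    splits=[-float("inf"), 0.5, 1.4, float("inf")],
    inputCol="values", outputCol="buckets",
    handleInvalid="keep")  # NaN rows are kept in an extra bucket

bucketizer.transform(df).show()

# "skip" filters out rows with invalid values instead;
# "error" (the default) raises on them.
bucketizer.setHandleInvalid("skip")
assert bucketizer.transform(df).count() == 4
```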
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rwxr-xr-x  python/pyspark/ml/feature.py  85
1 file changed, 71 insertions(+), 14 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index aada38d1ad..1d62b32534 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -125,10 +125,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"""
Maps a column of continuous features to a column of feature buckets.
- >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+ >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+ >>> df = spark.createDataFrame(values, ["values"])
>>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
... inputCol="values", outputCol="buckets")
- >>> bucketed = bucketizer.transform(df).collect()
+ >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
+ >>> len(bucketed)
+ 6
>>> bucketed[0].buckets
0.0
>>> bucketed[1].buckets
@@ -144,6 +147,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
>>> loadedBucketizer = Bucketizer.load(bucketizerPath)
>>> loadedBucketizer.getSplits() == bucketizer.getSplits()
True
+ >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
+ >>> len(bucketed)
+ 4
.. versionadded:: 1.4.0
"""
@@ -158,21 +164,28 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"splits specified will be treated as errors.",
typeConverter=TypeConverters.toListFloat)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+ "Options are skip (filter out rows with invalid values), " +
+ "error (throw an error), or keep (keep invalid values in a special " +
+ "additional bucket).",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
- def __init__(self, splits=None, inputCol=None, outputCol=None):
+ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
"""
- __init__(self, splits=None, inputCol=None, outputCol=None)
+ __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
"""
super(Bucketizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
+ self._setDefault(handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
- def setParams(self, splits=None, inputCol=None, outputCol=None):
+ def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
"""
- setParams(self, splits=None, inputCol=None, outputCol=None)
+ setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this Bucketizer.
"""
kwargs = self.setParams._input_kwargs
@@ -192,6 +205,20 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"""
return self.getOrDefault(self.splits)
+ @since("2.1.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
+ @since("2.1.0")
+ def getHandleInvalid(self):
+ """
+ Gets the value of :py:attr:`handleInvalid` or its default value.
+ """
+ return self.getOrDefault(self.handleInvalid)
+
@inherit_doc
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1157,12 +1184,17 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
:py:attr:`relativeError` parameter.
The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.
- >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+ >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+ >>> df = spark.createDataFrame(values, ["values"])
>>> qds = QuantileDiscretizer(numBuckets=2,
- ... inputCol="values", outputCol="buckets", relativeError=0.01)
+ ... inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
>>> qds.getRelativeError()
0.01
>>> bucketizer = qds.fit(df)
+ >>> qds.setHandleInvalid("keep").fit(df).transform(df).count()
+ 6
+ >>> qds.setHandleInvalid("skip").fit(df).transform(df).count()
+ 4
>>> splits = bucketizer.getSplits()
>>> splits[0]
-inf
@@ -1190,23 +1222,33 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
"Must be in the range [0, 1].",
typeConverter=TypeConverters.toFloat)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+ "Options are skip (filter out rows with invalid values), " +
+ "error (throw an error), or keep (keep invalid values in a special " +
+ "additional bucket).",
+ typeConverter=TypeConverters.toString)
+
@keyword_only
- def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+ def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+ handleInvalid="error"):
"""
- __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+ __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+ handleInvalid="error")
"""
super(QuantileDiscretizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
self.uid)
- self._setDefault(numBuckets=2, relativeError=0.001)
+ self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
- def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+ def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+ handleInvalid="error"):
"""
- setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+ setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+ handleInvalid="error")
Set the params for the QuantileDiscretizer
"""
kwargs = self.setParams._input_kwargs
@@ -1240,13 +1282,28 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
"""
return self.getOrDefault(self.relativeError)
+ @since("2.1.0")
+ def setHandleInvalid(self, value):
+ """
+ Sets the value of :py:attr:`handleInvalid`.
+ """
+ return self._set(handleInvalid=value)
+
+ @since("2.1.0")
+ def getHandleInvalid(self):
+ """
+ Gets the value of :py:attr:`handleInvalid` or its default value.
+ """
+ return self.getOrDefault(self.handleInvalid)
+
def _create_model(self, java_model):
"""
Private method to convert the java_model to a Python model.
"""
return Bucketizer(splits=list(java_model.getSplits()),
inputCol=self.getInputCol(),
- outputCol=self.getOutputCol())
+ outputCol=self.getOutputCol(),
+ handleInvalid=self.getHandleInvalid())
@inherit_doc
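Because `_create_model` now forwards `handleInvalid`, the `Bucketizer` returned by `QuantileDiscretizer.fit` inherits the estimator's setting. A short sketch of that behavior, again assuming a live `SparkSession` bound to `spark`:

```python
from pyspark.ml.feature import QuantileDiscretizer

values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),)]
df = spark.createDataFrame(values, ["values"])

qds = QuantileDiscretizer(numBuckets=2, inputCol="values",
                          outputCol="buckets", handleInvalid="keep")
model = qds.fit(df)  # a Bucketizer built via _create_model

# The fitted Bucketizer carries the estimator's handleInvalid setting,
# so the NaN row survives transform() and lands in the extra bucket.
assert model.getHandleInvalid() == "keep"
assert model.transform(df).count() == 5
```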