diff options
author | Daoyuan Wang <daoyuan.wang@intel.com> | 2015-05-12 10:23:41 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-05-12 10:23:41 -0700 |
commit | d86ce845840a92b4dde7975082738ed94ab8c570 (patch) | |
tree | 7c1f437169cf8132bd5c6b70bc374bea717e13ab /python/pyspark/sql/dataframe.py | |
parent | ec6f2a9774167014566fb9608ee4394d2ce5fd6a (diff) | |
download | spark-d86ce845840a92b4dde7975082738ed94ab8c570.tar.gz spark-d86ce845840a92b4dde7975082738ed94ab8c570.tar.bz2 spark-d86ce845840a92b4dde7975082738ed94ab8c570.zip |
[SPARK-6876] [PySpark] [SQL] add DataFrame na.replace in pyspark
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #6003 from adrian-wang/pynareplace and squashes the following commits:
672efba [Daoyuan Wang] remove py2.7 feature
4a148f7 [Daoyuan Wang] to_replace support dict, value support single value, and add full tests
9e232e7 [Daoyuan Wang] rename scala map
af0268a [Daoyuan Wang] remove na
63ac579 [Daoyuan Wang] add na.replace in pyspark
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r-- | python/pyspark/sql/dataframe.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 72180f6d05..078acfdf7e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -578,6 +578,10 @@ class DataFrame(object): """Return a JVM Seq of Columns from a list of Column or names""" return _to_seq(self.sql_ctx._sc, cols, converter) + def _jmap(self, jm): + """Return a JVM Scala Map from a dict""" + return _to_scala_map(self.sql_ctx._sc, jm) + def _jcols(self, *cols): """Return a JVM Seq of Columns from a list of Column or column names @@ -924,6 +928,80 @@ class DataFrame(object): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def replace(self, to_replace, value, subset=None): + """Returns a new :class:`DataFrame` replacing a value with another value. + + :param to_replace: int, long, float, string, or list. + Value to be replaced. + If the value is a dict, then `value` is ignored and `to_replace` must be a + mapping from column name (string) to replacement value. The value to be + replaced must be an int, long, float, or string. + :param value: int, long, float, string, or list. + Value to use to replace holes. + The replacement value must be an int, long, float, or string. If `value` is a + list or tuple, `value` should be of the same length with `to_replace`. + :param subset: optional list of column names to consider. + Columns specified in subset that do not have matching data type are ignored. + For example, if `value` is a string, and subset contains a non-string column, + then the non-string column is simply ignored. + >>> df4.replace(10, 20).show() + +----+------+-----+ + | age|height| name| + +----+------+-----+ + | 20| 80|Alice| + | 5| null| Bob| + |null| null| Tom| + |null| null| null| + +----+------+-----+ + + >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show() + +----+------+----+ + | age|height|name| + +----+------+----+ + | 10| 80| A| + | 5| null| B| + |null| null| Tom| + |null| null|null| + +----+------+----+ + """ + if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)): + raise ValueError( + "to_replace should be a float, int, long, string, list, tuple, or dict") + + if not isinstance(value, (float, int, long, basestring, list, tuple)): + raise ValueError("value should be a float, int, long, string, list, or tuple") + + rep_dict = dict() + + if isinstance(to_replace, (float, int, long, basestring)): + to_replace = [to_replace] + + if isinstance(to_replace, tuple): + to_replace = list(to_replace) + + if isinstance(value, tuple): + value = list(value) + + if isinstance(to_replace, list) and isinstance(value, list): + if len(to_replace) != len(value): + raise ValueError("to_replace and value lists should be of the same length") + rep_dict = dict(zip(to_replace, value)) + elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)): + rep_dict = dict([(tr, value) for tr in to_replace]) + elif isinstance(to_replace, dict): + rep_dict = to_replace + + if subset is None: + return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx) + elif isinstance(subset, basestring): + subset = [subset] + + if not isinstance(subset, (list, tuple)): + raise ValueError("subset should be a list or tuple of column names") + + return DataFrame( + self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + def corr(self, col1, col2, method=None): """ Calculates the correlation of two columns of a DataFrame as a double value. Currently only @@ -1226,6 +1304,13 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) +def _to_scala_map(sc, jm): + """ + Convert a dict into a JVM Map. + """ + return sc._jvm.PythonUtils.toScalaMap(jm) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): |