aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/dataframe.py
diff options
context:
space:
mode:
author: Daoyuan Wang <daoyuan.wang@intel.com> 2015-05-12 10:23:41 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-05-12 10:23:41 -0700
commit: d86ce845840a92b4dde7975082738ed94ab8c570 (patch)
tree: 7c1f437169cf8132bd5c6b70bc374bea717e13ab /python/pyspark/sql/dataframe.py
parent: ec6f2a9774167014566fb9608ee4394d2ce5fd6a (diff)
download: spark-d86ce845840a92b4dde7975082738ed94ab8c570.tar.gz
spark-d86ce845840a92b4dde7975082738ed94ab8c570.tar.bz2
spark-d86ce845840a92b4dde7975082738ed94ab8c570.zip
[SPARK-6876] [PySpark] [SQL] add DataFrame na.replace in pyspark
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #6003 from adrian-wang/pynareplace and squashes the following commits:

672efba [Daoyuan Wang] remove py2.7 feature
4a148f7 [Daoyuan Wang] to_replace support dict, value support single value, and add full tests
9e232e7 [Daoyuan Wang] rename scala map
af0268a [Daoyuan Wang] remove na
63ac579 [Daoyuan Wang] add na.replace in pyspark
Diffstat (limited to 'python/pyspark/sql/dataframe.py')
-rw-r--r-- python/pyspark/sql/dataframe.py 85
1 files changed, 85 insertions, 0 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 72180f6d05..078acfdf7e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -578,6 +578,10 @@ class DataFrame(object):
"""Return a JVM Seq of Columns from a list of Column or names"""
return _to_seq(self.sql_ctx._sc, cols, converter)
+ def _jmap(self, jm):
+ """Return a JVM Scala Map from a dict"""
+ return _to_scala_map(self.sql_ctx._sc, jm)
+
def _jcols(self, *cols):
"""Return a JVM Seq of Columns from a list of Column or column names
@@ -924,6 +928,80 @@ class DataFrame(object):
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
+ def replace(self, to_replace, value, subset=None):
+ """Returns a new :class:`DataFrame` replacing a value with another value.
+
+ :param to_replace: int, long, float, string, list, tuple, or dict.
+ Value to be replaced.
+ If the value is a dict, then `to_replace` must be a mapping from each value to be
+ replaced to its replacement; `value` is still type-checked but is otherwise unused.
+ The value to be replaced must be an int, long, float, or string.
+ :param value: int, long, float, string, list, or tuple.
+ Value to use as the replacement.
+ The replacement value must be an int, long, float, or string. If `value` is a
+ list or tuple, `value` should be of the same length as `to_replace`.
+ :param subset: optional column name (string) or list/tuple of column names to consider.
+ Columns specified in subset that do not have matching data type are ignored.
+ For example, if `value` is a string, and subset contains a non-string column,
+ then the non-string column is simply ignored.
+ >>> df4.replace(10, 20).show()
+ +----+------+-----+
+ | age|height| name|
+ +----+------+-----+
+ | 20| 80|Alice|
+ | 5| null| Bob|
+ |null| null| Tom|
+ |null| null| null|
+ +----+------+-----+
+
+ >>> df4.replace(['Alice', 'Bob'], ['A', 'B'], 'name').show()
+ +----+------+----+
+ | age|height|name|
+ +----+------+----+
+ | 10| 80| A|
+ | 5| null| B|
+ |null| null| Tom|
+ |null| null|null|
+ +----+------+----+
+ """
+ if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+ raise ValueError(
+ "to_replace should be a float, int, long, string, list, tuple, or dict")
+
+ if not isinstance(value, (float, int, long, basestring, list, tuple)):
+ raise ValueError("value should be a float, int, long, string, list, or tuple")
+
+ rep_dict = dict()
+
+ if isinstance(to_replace, (float, int, long, basestring)):
+ to_replace = [to_replace]
+
+ if isinstance(to_replace, tuple):
+ to_replace = list(to_replace)
+
+ if isinstance(value, tuple):
+ value = list(value)
+
+ if isinstance(to_replace, list) and isinstance(value, list):
+ if len(to_replace) != len(value):
+ raise ValueError("to_replace and value lists should be of the same length")
+ rep_dict = dict(zip(to_replace, value))
+ elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
+ rep_dict = dict([(tr, value) for tr in to_replace])
+ elif isinstance(to_replace, dict):
+ rep_dict = to_replace
+
+ if subset is None:
+ return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
+ elif isinstance(subset, basestring):
+ subset = [subset]
+
+ if not isinstance(subset, (list, tuple)):
+ raise ValueError("subset should be a list or tuple of column names")
+
+ return DataFrame(
+ self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
+
def corr(self, col1, col2, method=None):
"""
Calculates the correlation of two columns of a DataFrame as a double value. Currently only
@@ -1226,6 +1304,13 @@ def _to_seq(sc, cols, converter=None):
return sc._jvm.PythonUtils.toSeq(cols)
+def _to_scala_map(sc, jm):
+ """
+ Convert a dict into a JVM Map.
+ """
+ return sc._jvm.PythonUtils.toScalaMap(jm)
+
+
def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):