aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorzero323 <zero323@users.noreply.github.com>2017-04-05 11:47:40 -0700
committerHolden Karau <holden@us.ibm.com>2017-04-05 11:47:40 -0700
commite2773996b8d1c0214d9ffac634a059b4923caf7b (patch)
treef04cc80d708c52953b48fb2a1165a5c005783206
parenta2d8d767d933321426a4eb9df1583e017722d7d6 (diff)
downloadspark-e2773996b8d1c0214d9ffac634a059b4923caf7b.tar.gz
spark-e2773996b8d1c0214d9ffac634a059b4923caf7b.tar.bz2
spark-e2773996b8d1c0214d9ffac634a059b4923caf7b.zip
[SPARK-19454][PYTHON][SQL] DataFrame.replace improvements
## What changes were proposed in this pull request? - Allows skipping `value` argument if `to_replace` is a `dict`: ```python df = sc.parallelize([("Alice", 1, 3.0)]).toDF() df.replace({"Alice": "Bob"}).show() ``` - Adds validation step to ensure homogeneous values / replacements. - Simplifies internal control flow. - Improves unit tests coverage. ## How was this patch tested? Existing unit tests, additional unit tests, manual testing. Author: zero323 <zero323@users.noreply.github.com> Closes #16793 from zero323/SPARK-19454.
-rw-r--r--python/pyspark/sql/dataframe.py81
-rw-r--r--python/pyspark/sql/tests.py72
2 files changed, 128 insertions, 25 deletions
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a24512f53c..774caf53f3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,6 +25,8 @@ if sys.version >= '3':
else:
from itertools import imap as map
+import warnings
+
from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -1281,7 +1283,7 @@ class DataFrame(object):
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
@since(1.4)
- def replace(self, to_replace, value, subset=None):
+ def replace(self, to_replace, value=None, subset=None):
"""Returns a new :class:`DataFrame` replacing a value with another value.
:func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
aliases of each other.
@@ -1326,43 +1328,72 @@ class DataFrame(object):
|null| null|null|
+----+------+----+
"""
- if not isinstance(to_replace, (float, int, long, basestring, list, tuple, dict)):
+ # Helper functions
+ def all_of(types):
+ """Given a type or tuple of types and a sequence of xs
+ check if each x is instance of type(s)
+
+ >>> all_of(bool)([True, False])
+ True
+ >>> all_of(basestring)(["a", 1])
+ False
+ """
+ def all_of_(xs):
+ return all(isinstance(x, types) for x in xs)
+ return all_of_
+
+ all_of_bool = all_of(bool)
+ all_of_str = all_of(basestring)
+ all_of_numeric = all_of((float, int, long))
+
+ # Validate input types
+ valid_types = (bool, float, int, long, basestring, list, tuple)
+ if not isinstance(to_replace, valid_types + (dict, )):
raise ValueError(
- "to_replace should be a float, int, long, string, list, tuple, or dict")
+ "to_replace should be a float, int, long, string, list, tuple, or dict. "
+ "Got {0}".format(type(to_replace)))
- if not isinstance(value, (float, int, long, basestring, list, tuple)):
- raise ValueError("value should be a float, int, long, string, list, or tuple")
+ if not isinstance(value, valid_types) and not isinstance(to_replace, dict):
+ raise ValueError("If to_replace is not a dict, value should be "
+ "a float, int, long, string, list, or tuple. "
+ "Got {0}".format(type(value)))
+
+ if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
+ if len(to_replace) != len(value):
+ raise ValueError("to_replace and value lists should be of the same length. "
+ "Got {0} and {1}".format(len(to_replace), len(value)))
- rep_dict = dict()
+ if not (subset is None or isinstance(subset, (list, tuple, basestring))):
+ raise ValueError("subset should be a list or tuple of column names, "
+ "column name or None. Got {0}".format(type(subset)))
+ # Reshape input arguments if necessary
if isinstance(to_replace, (float, int, long, basestring)):
to_replace = [to_replace]
- if isinstance(to_replace, tuple):
- to_replace = list(to_replace)
+ if isinstance(value, (float, int, long, basestring)):
+ value = [value for _ in range(len(to_replace))]
- if isinstance(value, tuple):
- value = list(value)
-
- if isinstance(to_replace, list) and isinstance(value, list):
- if len(to_replace) != len(value):
- raise ValueError("to_replace and value lists should be of the same length")
- rep_dict = dict(zip(to_replace, value))
- elif isinstance(to_replace, list) and isinstance(value, (float, int, long, basestring)):
- rep_dict = dict([(tr, value) for tr in to_replace])
- elif isinstance(to_replace, dict):
+ if isinstance(to_replace, dict):
rep_dict = to_replace
+ if value is not None:
+ warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
+ else:
+ rep_dict = dict(zip(to_replace, value))
- if subset is None:
- return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
- elif isinstance(subset, basestring):
+ if isinstance(subset, basestring):
subset = [subset]
- if not isinstance(subset, (list, tuple)):
- raise ValueError("subset should be a list or tuple of column names")
+ # Verify we were not passed in mixed type generics."
+ if not any(all_of_type(rep_dict.keys()) and all_of_type(rep_dict.values())
+ for all_of_type in [all_of_bool, all_of_str, all_of_numeric]):
+ raise ValueError("Mixed type replacements are not supported")
- return DataFrame(
- self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
+ if subset is None:
+ return DataFrame(self._jdf.na().replace('*', rep_dict), self.sql_ctx)
+ else:
+ return DataFrame(
+ self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
@since(2.0)
def approxQuantile(self, col, probabilities, relativeError):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index db41b4edb6..2b2444304e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1779,6 +1779,78 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(row.age, 10)
self.assertEqual(row.height, None)
+ # replace with lists
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace([u'Alice'], [u'Ann']).first()
+ self.assertTupleEqual(row, (u'Ann', 10, 80.1))
+
+ # replace with dict
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace({10: 11}).first()
+ self.assertTupleEqual(row, (u'Alice', 11, 80.1))
+
+ # test backward compatibility with dummy value
+ dummy_value = 1
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace({'Alice': 'Bob'}, dummy_value).first()
+ self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+ # test dict with mixed numerics
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace({10: -10, 80.1: 90.5}).first()
+ self.assertTupleEqual(row, (u'Alice', -10, 90.5))
+
+ # replace with tuples
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace((u'Alice', ), (u'Bob', )).first()
+ self.assertTupleEqual(row, (u'Bob', 10, 80.1))
+
+ # replace multiple columns
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.0)], schema).replace((10, 80.0), (20, 90)).first()
+ self.assertTupleEqual(row, (u'Alice', 20, 90.0))
+
+ # test for mixed numerics
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.0)], schema).replace((10, 80), (20, 90.5)).first()
+ self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+ row = self.spark.createDataFrame(
+ [(u'Alice', 10, 80.0)], schema).replace({10: 20, 80: 90.5}).first()
+ self.assertTupleEqual(row, (u'Alice', 20, 90.5))
+
+ # replace with boolean
+ row = (self
+ .spark.createDataFrame([(u'Alice', 10, 80.0)], schema)
+ .selectExpr("name = 'Bob'", 'age <= 15')
+ .replace(False, True).first())
+ self.assertTupleEqual(row, (True, True))
+
+ # should fail if subset is not list, tuple or None
+ with self.assertRaises(ValueError):
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace({10: 11}, subset=1).first()
+
+ # should fail if to_replace and value have different length
+ with self.assertRaises(ValueError):
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace(["Alice", "Bob"], ["Eve"]).first()
+
+ # should fail if when received unexpected type
+ with self.assertRaises(ValueError):
+ from datetime import datetime
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace(datetime.now(), datetime.now()).first()
+
+ # should fail if provided mixed type replacements
+ with self.assertRaises(ValueError):
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace(["Alice", 10], ["Eve", 20]).first()
+
+ with self.assertRaises(ValueError):
+ self.spark.createDataFrame(
+ [(u'Alice', 10, 80.1)], schema).replace({u"Alice": u"Bob", 10: 20}).first()
+
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))