diff options
-rw-r--r-- | python/pyspark/streaming/kafka.py | 10 | ||||
-rw-r--r-- | python/pyspark/streaming/tests.py | 10 |
2 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index b35bbaf404..06e159172a 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -254,6 +254,16 @@ class TopicAndPartition(object): def _jTopicAndPartition(self, helper): return helper.createTopicAndPartition(self._topic, self._partition) + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self._topic == other._topic + and self._partition == other._partition) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + class Broker(object): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2c908daa8b..f7fa481d50 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -898,6 +898,16 @@ class KafkaStreamTests(PySparkStreamingTestCase): self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + def test_topic_and_partition_equality(self): + topic_and_partition_a = TopicAndPartition("foo", 0) + topic_and_partition_b = TopicAndPartition("foo", 0) + topic_and_partition_c = TopicAndPartition("bar", 0) + topic_and_partition_d = TopicAndPartition("foo", 1) + + self.assertEqual(topic_and_partition_a, topic_and_partition_b) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds |