From e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Wed, 20 Apr 2016 13:45:14 -0700 Subject: [SPARK-13842] [PYSPARK] pyspark.sql.types.StructType accessor enhancements ## What changes were proposed in this pull request? Expand the possible ways to interact with the contents of a `pyspark.sql.types.StructType` instance. - Iterating a `StructType` will iterate its fields - `[field.name for field in my_structtype]` - Indexing with a string will return a field by name - `my_structtype['my_field_name']` - Indexing with an integer will return a field by position - `my_structtype[0]` - Indexing with a slice will return a new `StructType` with just the chosen fields: - `my_structtype[1:3]` - The length is the number of fields (should also provide "truthiness" for free) - `len(my_structtype) == 2` ## How was this patch tested? Extended the unit test coverage in the accompanying `tests.py`. Author: Sheamus K. Parkes Closes #12251 from skparkes/pyspark-structtype-enhance. --- python/pyspark/sql/tests.py | 23 +++++++++++++++++++---- python/pyspark/sql/types.py | 44 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 9 deletions(-) (limited to 'python/pyspark/sql') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1e864b4cd1..3b1b2948e9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -802,11 +802,26 @@ class SQLTests(ReusedPySparkTestCase): self.assertNotEqual(struct1, struct2) # Catch exception raised during improper construction - try: + with self.assertRaises(ValueError): struct1 = StructType().add("name") - self.assertEqual(1, 0) - except ValueError: - self.assertEqual(1, 1) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + for field in struct1: + self.assertIsInstance(field, StructField) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(len(struct1), 2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertIs(struct1["f1"], struct1.fields[0]) + self.assertIs(struct1[0], struct1.fields[0]) + self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) + with self.assertRaises(KeyError): + not_a_field = struct1["f9"] + with self.assertRaises(IndexError): + not_a_field = struct1[9] + with self.assertRaises(TypeError): + not_a_field = struct1[9.9] def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 734c1533a2..f7cd4b80ca 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -442,6 +442,15 @@ class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. + + Iterating a :class:`StructType` will iterate its :class:`StructField`s. + A contained :class:`StructField` can be accessed by name or position. + + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct1["f1"] + StructField(f1,StringType,true) + >>> struct1[0] + StructField(f1,StringType,true) """ def __init__(self, fields=None): """ @@ -463,7 +472,7 @@ class StructType(DataType): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -508,19 +517,44 @@ class StructType(DataType): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) return self + def __iter__(self): + """Iterate the fields""" + return iter(self.fields) + + def __len__(self): + """Return the number of fields.""" + return len(self.fields) + + def __getitem__(self, key): + """Access fields by name or slice.""" + if isinstance(key, str): + for field in self: + if field.name == key: + return field + raise KeyError('No StructField named {0}'.format(key)) + elif isinstance(key, int): + try: + return self.fields[key] + except IndexError: + raise IndexError('StructType index out of range') + elif isinstance(key, slice): + return StructType(self.fields[key]) + else: + raise TypeError('StructType keys should be strings, integers or slices') + def simpleString(self): - return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + return 'struct<%s>' % (','.join(f.simpleString() for f in self)) def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self)) def jsonValue(self): return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} + "fields": [f.jsonValue() for f in self]} @classmethod def fromJson(cls, json): -- cgit v1.2.3