diff options
author:    Sheamus K. Parkes <shea.parkes@milliman.com>  2016-04-20 13:45:14 -0700
committer: Davies Liu <davies.liu@gmail.com>             2016-04-20 13:45:14 -0700
commit:    e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807 (patch)
tree:      f2aea5c83cb65e37e4bd883fbd5b2155e14eeb5f /python
parent:    7bc948557bb6169cbeec335f8400af09375a62d3 (diff)
download:  spark-e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807.tar.gz
           spark-e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807.tar.bz2
           spark-e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807.zip
[SPARK-13842] [PYSPARK] pyspark.sql.types.StructType accessor enhancements
## What changes were proposed in this pull request?
Expand the possible ways to interact with the contents of a `pyspark.sql.types.StructType` instance.
- Iterating a `StructType` will iterate its fields
- `[field.name for field in my_structtype]`
- Indexing with a string will return a field by name
- `my_structtype['my_field_name']`
- Indexing with an integer will return a field by position
- `my_structtype[0]`
- Indexing with a slice will return a new `StructType` with just the chosen fields:
- `my_structtype[1:3]`
- The length is the number of fields (should also provide "truthiness" for free)
- `len(my_structtype) == 2`
## How was this patch tested?
Extended the unit test coverage in the accompanying `tests.py`.
Author: Sheamus K. Parkes <shea.parkes@milliman.com>
Closes #12251 from skparkes/pyspark-structtype-enhance.
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/tests.py | 23
-rw-r--r--  python/pyspark/sql/types.py | 44
2 files changed, 58 insertions, 9 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1e864b4cd1..3b1b2948e9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -802,11 +802,26 @@ class SQLTests(ReusedPySparkTestCase): self.assertNotEqual(struct1, struct2) # Catch exception raised during improper construction - try: + with self.assertRaises(ValueError): struct1 = StructType().add("name") - self.assertEqual(1, 0) - except ValueError: - self.assertEqual(1, 1) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + for field in struct1: + self.assertIsInstance(field, StructField) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(len(struct1), 2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertIs(struct1["f1"], struct1.fields[0]) + self.assertIs(struct1[0], struct1.fields[0]) + self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) + with self.assertRaises(KeyError): + not_a_field = struct1["f9"] + with self.assertRaises(IndexError): + not_a_field = struct1[9] + with self.assertRaises(TypeError): + not_a_field = struct1[9.9] def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 734c1533a2..f7cd4b80ca 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -442,6 +442,15 @@ class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. + + Iterating a :class:`StructType` will iterate its :class:`StructField`s. + A contained :class:`StructField` can be accessed by name or position. 
+ + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct1["f1"] + StructField(f1,StringType,true) + >>> struct1[0] + StructField(f1,StringType,true) """ def __init__(self, fields=None): """ @@ -463,7 +472,7 @@ class StructType(DataType): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -508,19 +517,44 @@ class StructType(DataType): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) return self + def __iter__(self): + """Iterate the fields""" + return iter(self.fields) + + def __len__(self): + """Return the number of fields.""" + return len(self.fields) + + def __getitem__(self, key): + """Access fields by name or slice.""" + if isinstance(key, str): + for field in self: + if field.name == key: + return field + raise KeyError('No StructField named {0}'.format(key)) + elif isinstance(key, int): + try: + return self.fields[key] + except IndexError: + raise IndexError('StructType index out of range') + elif isinstance(key, slice): + return StructType(self.fields[key]) + else: + raise TypeError('StructType keys should be strings, integers or slices') + def simpleString(self): - return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + return 'struct<%s>' % (','.join(f.simpleString() for f in self)) def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self)) def jsonValue(self): return {"type": self.typeName(), - "fields": 
[f.jsonValue() for f in self.fields]} + "fields": [f.jsonValue() for f in self]} @classmethod def fromJson(cls, json): |