author		Sheamus K. Parkes <shea.parkes@milliman.com>	2016-04-20 13:45:14 -0700
committer	Davies Liu <davies.liu@gmail.com>	2016-04-20 13:45:14 -0700
commit		e7791c4f69aaa150e6ddb30b6d4ba2b0ea3c7807 (patch)
tree		f2aea5c83cb65e37e4bd883fbd5b2155e14eeb5f /python
parent		7bc948557bb6169cbeec335f8400af09375a62d3 (diff)
[SPARK-13842] [PYSPARK] pyspark.sql.types.StructType accessor enhancements
## What changes were proposed in this pull request?

Expand the possible ways to interact with the contents of a `pyspark.sql.types.StructType` instance.

- Iterating a `StructType` will iterate its fields
  - `[field.name for field in my_structtype]`
- Indexing with a string will return a field by name
  - `my_structtype['my_field_name']`
- Indexing with an integer will return a field by position
  - `my_structtype[0]`
- Indexing with a slice will return a new `StructType` with just the chosen fields:
  - `my_structtype[1:3]`
- The length is the number of fields (should also provide "truthiness" for free)
  - `len(my_structtype) == 2`

## How was this patch tested?

Extended the unit test coverage in the accompanying `tests.py`.

Author: Sheamus K. Parkes <shea.parkes@milliman.com>

Closes #12251 from skparkes/pyspark-structtype-enhance.
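A quick illustrative sketch of the new accessors in action (assumes a PySpark build that includes this patch; only `pyspark.sql.types` is needed, so no running SparkContext is required, and the `my_structtype` schema below is a made-up example):

```python
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

my_structtype = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("city", StringType(), True),
])

# Iteration yields the contained StructField objects.
assert [field.name for field in my_structtype] == ["name", "age", "city"]

# Indexing by name, by position, or by slice.
assert my_structtype["age"] is my_structtype.fields[1]
assert my_structtype[0].name == "name"
assert isinstance(my_structtype[1:3], StructType)

# len() counts the fields, which also gives "truthiness" for free.
assert len(my_structtype) == 3
assert my_structtype  # a non-empty StructType is truthy
```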
Diffstat (limited to 'python')
-rw-r--r--	python/pyspark/sql/tests.py	23
-rw-r--r--	python/pyspark/sql/types.py	44
2 files changed, 58 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1e864b4cd1..3b1b2948e9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -802,11 +802,26 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
-        try:
+        with self.assertRaises(ValueError):
             struct1 = StructType().add("name")
-            self.assertEqual(1, 0)
-        except ValueError:
-            self.assertEqual(1, 1)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        for field in struct1:
+            self.assertIsInstance(field, StructField)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        self.assertEqual(len(struct1), 2)
+
+        struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+        self.assertIs(struct1["f1"], struct1.fields[0])
+        self.assertIs(struct1[0], struct1.fields[0])
+        self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
+        with self.assertRaises(KeyError):
+            not_a_field = struct1["f9"]
+        with self.assertRaises(IndexError):
+            not_a_field = struct1[9]
+        with self.assertRaises(TypeError):
+            not_a_field = struct1[9.9]
 
     def test_metadata_null(self):
         from pyspark.sql.types import StructType, StringType, StructField
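As an aside, the first change in the hunk above swaps the manual try/except/assert pattern for `unittest`'s `assertRaises` context manager. A minimal standalone sketch of the two idioms side by side (the `might_raise` helper is hypothetical, standing in for `StructType().add("name")`):

```python
import unittest

def might_raise():
    # Hypothetical stand-in for StructType().add("name"), which raises
    # ValueError because a bare field name needs an explicit data type.
    raise ValueError("a field name alone is not enough to build a StructField")

class AssertRaisesStyle(unittest.TestCase):
    def test_old_style(self):
        # Pattern removed by the patch: encode pass/fail as dummy assertions.
        try:
            might_raise()
            self.assertEqual(1, 0)  # reached only if nothing was raised
        except ValueError:
            self.assertEqual(1, 1)  # trivially true: the error arrived

    def test_new_style(self):
        # Pattern adopted by the patch: fails automatically if no ValueError.
        with self.assertRaises(ValueError):
            might_raise()

if __name__ == "__main__":
    unittest.main()
```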
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 734c1533a2..f7cd4b80ca 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -442,6 +442,15 @@ class StructType(DataType):
     """Struct type, consisting of a list of :class:`StructField`.
 
     This is the data type representing a :class:`Row`.
+
+    Iterating a :class:`StructType` will iterate its :class:`StructField`s.
+    A contained :class:`StructField` can be accessed by name or position.
+
+    >>> struct1 = StructType([StructField("f1", StringType(), True)])
+    >>> struct1["f1"]
+    StructField(f1,StringType,true)
+    >>> struct1[0]
+    StructField(f1,StringType,true)
     """
     def __init__(self, fields=None):
         """
@@ -463,7 +472,7 @@ class StructType(DataType):
             self.names = [f.name for f in fields]
             assert all(isinstance(f, StructField) for f in fields),\
                 "fields should be a list of StructField"
-        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
+        self._needSerializeAnyField = any(f.needConversion() for f in self)
 
     def add(self, field, data_type=None, nullable=True, metadata=None):
         """
@@ -508,19 +517,44 @@ class StructType(DataType):
                 data_type_f = data_type
         self.fields.append(StructField(field, data_type_f, nullable, metadata))
         self.names.append(field)
-        self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
+        self._needSerializeAnyField = any(f.needConversion() for f in self)
         return self
 
+    def __iter__(self):
+        """Iterate the fields"""
+        return iter(self.fields)
+
+    def __len__(self):
+        """Return the number of fields."""
+        return len(self.fields)
+
+    def __getitem__(self, key):
+        """Access fields by name or slice."""
+        if isinstance(key, str):
+            for field in self:
+                if field.name == key:
+                    return field
+            raise KeyError('No StructField named {0}'.format(key))
+        elif isinstance(key, int):
+            try:
+                return self.fields[key]
+            except IndexError:
+                raise IndexError('StructType index out of range')
+        elif isinstance(key, slice):
+            return StructType(self.fields[key])
+        else:
+            raise TypeError('StructType keys should be strings, integers or slices')
+
     def simpleString(self):
-        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
+        return 'struct<%s>' % (','.join(f.simpleString() for f in self))
 
     def __repr__(self):
         return ("StructType(List(%s))" %
-                ",".join(str(field) for field in self.fields))
+                ",".join(str(field) for field in self))
 
     def jsonValue(self):
         return {"type": self.typeName(),
-                "fields": [f.jsonValue() for f in self.fields]}
+                "fields": [f.jsonValue() for f in self]}
 
     @classmethod
     def fromJson(cls, json):
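Taken together, the new `__getitem__` dispatches on the key's type, mirroring Python's container conventions. A short usage sketch of the resulting semantics (assumes a PySpark build that includes this patch; the two-field schema is a made-up example):

```python
from pyspark.sql.types import StructType, StructField, StringType

schema = StructType([StructField("f1", StringType(), True),
                     StructField("f2", StringType(), True)])

print(schema["f1"])   # by name     -> StructField(f1,StringType,true)
print(schema[1])      # by position -> StructField(f2,StringType,true)
print(schema[0:1])    # by slice    -> StructType(List(StructField(f1,StringType,true)))

# Failed lookups raise the exception matching the key type:
for bad_key, expected in [("nope", KeyError), (9, IndexError), (1.5, TypeError)]:
    try:
        schema[bad_key]
    except expected:
        pass  # unknown name -> KeyError; out of range -> IndexError; float -> TypeError
```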