about summary refs log tree commit diff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- python/pyspark/sql/tests.py 29
-rw-r--r-- python/pyspark/sql/types.py 52
2 files changed, 77 insertions, 4 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ffee43a94b..34f397d0ff 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -516,6 +516,35 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
df.filter(df.a.between(df.b, df.c)).collect())
+ def test_struct_type(self):
+ from pyspark.sql.types import StructType, StringType, StructField
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ # Catch exception raised during improper construction
+ try:
+ struct1 = StructType().add("name")
+ self.assertEqual(1, 0)
+ except ValueError:
+ self.assertEqual(1, 1)
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 932686e5e4..ae9344e610 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -355,8 +355,7 @@ class StructType(DataType):
This is the data type representing a :class:`Row`.
"""
-
- def __init__(self, fields):
+ def __init__(self, fields=None):
"""
>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct2 = StructType([StructField("f1", StringType(), True)])
@@ -368,8 +367,53 @@ class StructType(DataType):
>>> struct1 == struct2
False
"""
- assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
- self.fields = fields
+ if not fields:
+ self.fields = []
+ else:
+ self.fields = fields
+ assert all(isinstance(f, StructField) for f in fields),\
+ "fields should be a list of StructField"
+
+ def add(self, field, data_type=None, nullable=True, metadata=None):
+ """
+ Construct a StructType by adding new elements to it to define the schema. The method accepts
+ either:
+ a) A single parameter which is a StructField object.
+ b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
+ metadata (optional)). The data_type parameter may be either a String or a DataType object.
+
+ >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ >>> struct2 = StructType([StructField("f1", StringType(), True),\
+ StructField("f2", StringType(), True, None)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType().add(StructField("f1", StringType(), True))
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
+ >>> struct1 == struct2
+ True
+ >>> struct1 = StructType().add("f1", "string", True)
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
+ >>> struct1 == struct2
+ True
+
+ :param field: Either the name of the field or a StructField object
+ :param data_type: If present, the DataType of the StructField to create
+ :param nullable: Whether the field to add should be nullable (default True)
+ :param metadata: Any additional metadata (default None)
+ :return: this StructType, updated in place (returned to allow chained calls)
+ """
+ if isinstance(field, StructField):
+ self.fields.append(field)
+ else:
+ if isinstance(field, str) and data_type is None:
+ raise ValueError("Must specify DataType if passing name of struct_field to create.")
+
+ if isinstance(data_type, str):
+ data_type_f = _parse_datatype_json_value(data_type)
+ else:
+ data_type_f = data_type
+ self.fields.append(StructField(field, data_type_f, nullable, metadata))
+ return self
def simpleString(self):
return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))