diff options
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r-- | python/pyspark/sql/types.py | 52 |
1 files changed, 48 insertions, 4 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 932686e5e4..ae9344e610 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -355,8 +355,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -368,8 +367,53 @@ class StructType(DataType): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) |