aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ffee43a94b..34f397d0ff 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -516,6 +516,35 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)],
df.filter(df.a.between(df.b, df.c)).collect())
+ def test_struct_type(self):
+ from pyspark.sql.types import StructType, StringType, StructField
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True),
+ StructField("f2", StringType(), True, None)])
+ self.assertEqual(struct1, struct2)
+
+ struct1 = (StructType().add(StructField("f1", StringType(), True))
+ .add(StructField("f2", StringType(), True, None)))
+ struct2 = StructType([StructField("f1", StringType(), True)])
+ self.assertNotEqual(struct1, struct2)
+
+ # Catch exception raised during improper construction
+ try:
+ struct1 = StructType().add("name")
+ self.assertEqual(1, 0)
+ except ValueError:
+ self.assertEqual(1, 1)
+
def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()