about summary refs log tree commit diff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- python/pyspark/sql/tests.py 28
1 files changed, 22 insertions, 6 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 66827d4885..4d7cad5a1a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -151,6 +151,17 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(self.sqlCtx.range(-2).count(), 0)
self.assertEqual(self.sqlCtx.range(3).count(), 3)
+ def test_duplicated_column_names(self):
+ df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
+ row = df.select('*').first()
+ self.assertEqual(1, row[0])
+ self.assertEqual(2, row[1])
+ self.assertEqual("Row(c=1, c=2)", str(row))
+ # Cannot access columns
+ self.assertRaises(AnalysisException, lambda: df.select(df[0]).first())
+ self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
+ self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
+
def test_explode(self):
from pyspark.sql.functions import explode
d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
@@ -401,6 +412,14 @@ class SQLTests(ReusedPySparkTestCase):
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ def test_udf_with_udt(self):
+ from pyspark.sql.tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ df = self.sc.parallelize([row]).toDF()
+ self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+ udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
@@ -693,12 +712,9 @@ class SQLTests(ReusedPySparkTestCase):
utcnow = datetime.datetime.fromtimestamp(ts, utc)
df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
day1, now1, utcnow1 = df.first()
- # Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
- self.assertEqual(day1.date(), day)
- # Pyrolite does not support microsecond, the error should be
- # less than 1 millisecond
- self.assertTrue(now - now1 < datetime.timedelta(0.001))
- self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
+ self.assertEqual(day1, day)
+ self.assertEqual(now, now1)
+ self.assertEqual(now, utcnow1)
def test_decimal(self):
from decimal import Decimal