-rw-r--r--  python/pyspark/sql/tests.py                                              32
-rw-r--r--  python/pyspark/sql/types.py                                              27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala   3
3 files changed, 51 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a6fce50c76..b5fbb7d098 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@ import shutil
import tempfile
import pickle
import functools
+import time
import datetime

import py4j
@@ -47,6 +48,20 @@ from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.window import Window


+class UTC(datetime.tzinfo):
+    """UTC"""
+    ZERO = datetime.timedelta(0)
+
+    def utcoffset(self, dt):
+        return self.ZERO
+
+    def tzname(self, dt):
+        return "UTC"
+
+    def dst(self, dt):
+        return self.ZERO
+
+
class ExamplePointUDT(UserDefinedType):
"""
User-defined type (UDT) for ExamplePoint.
@@ -588,6 +603,23 @@ class SQLTests(ReusedPySparkTestCase):
        self.assertEqual(0, df.filter(df.date > date).count())
        self.assertEqual(0, df.filter(df.time > time).count())

+    def test_time_with_timezone(self):
+        day = datetime.date.today()
+        now = datetime.datetime.now()
+        ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
+        # classes defined in __main__ are not serializable
+        from pyspark.sql.tests import UTC
+        utc = UTC()
+        utcnow = datetime.datetime.fromtimestamp(ts, utc)
+        df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+        day1, now1, utcnow1 = df.first()
+        # Pyrolite serializes java.sql.Date as a datetime; this will be fixed in a newer version
+        self.assertEqual(day1.date(), day)
+        # Pyrolite does not support microseconds, so the error should be
+        # less than 1 millisecond
+        self.assertTrue(now - now1 < datetime.timedelta(milliseconds=1))
+        self.assertTrue(now - utcnow1 < datetime.timedelta(milliseconds=1))
+
    def test_dropna(self):
        schema = StructType([
            StructField("name", StringType(), True),
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8f286b631f..23d9adb0da 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -655,12 +655,15 @@ def _need_python_to_sql_conversion(dataType):
            _need_python_to_sql_conversion(dataType.valueType)
    elif isinstance(dataType, UserDefinedType):
        return True
-    elif isinstance(dataType, TimestampType):
+    elif isinstance(dataType, (DateType, TimestampType)):
        return True
    else:
        return False


+EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+
def _python_to_sql_converter(dataType):
"""
Returns a converter that converts a Python object into a SQL datum for the given type.
@@ -698,26 +701,32 @@ def _python_to_sql_converter(dataType):
                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
                else:
                    return tuple(c(v) for c, v in zip(converters, obj))
-            else:
+            elif obj is not None:
                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
        return converter
    elif isinstance(dataType, ArrayType):
        element_converter = _python_to_sql_converter(dataType.elementType)
-        return lambda a: [element_converter(v) for v in a]
+        return lambda a: a and [element_converter(v) for v in a]
    elif isinstance(dataType, MapType):
        key_converter = _python_to_sql_converter(dataType.keyType)
        value_converter = _python_to_sql_converter(dataType.valueType)
-        return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+        return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+
    elif isinstance(dataType, UserDefinedType):
-        return lambda obj: dataType.serialize(obj)
+        return lambda obj: obj and dataType.serialize(obj)
+
+    elif isinstance(dataType, DateType):
+        return lambda d: d and d.toordinal() - EPOCH_ORDINAL
+
    elif isinstance(dataType, TimestampType):
        def to_posix_timestamp(dt):
-            if dt.tzinfo is None:
-                return int(time.mktime(dt.timetuple()) * 1e7 + dt.microsecond * 10)
-            else:
-                return int(calendar.timegm(dt.utctimetuple()) * 1e7 + dt.microsecond * 10)
+            if dt:
+                seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+                           else time.mktime(dt.timetuple()))
+                return int(seconds * 1e7 + dt.microsecond * 10)
        return to_posix_timestamp
+
    else:
        raise ValueError("Unexpected type %r" % dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 955b478a48..b1333ec09a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -28,8 +28,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Row, _}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule