Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--  python/pyspark/sql/_types.py    | 27
-rw-r--r--  python/pyspark/sql/context.py   | 13
-rw-r--r--  python/pyspark/sql/dataframe.py | 18
-rw-r--r--  python/pyspark/sql/tests.py     | 11
4 files changed, 46 insertions(+), 23 deletions(-)
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py
index 110d1152fb..95fb91ad43 100644
--- a/python/pyspark/sql/_types.py
+++ b/python/pyspark/sql/_types.py
@@ -17,6 +17,7 @@
import sys
import decimal
+import time
import datetime
import keyword
import warnings
@@ -30,6 +31,9 @@ if sys.version >= "3":
long = int
unicode = str
+from py4j.protocol import register_input_converter
+from py4j.java_gateway import JavaClass
+
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -1237,6 +1241,29 @@ class Row(tuple):
return "<Row(%s)>" % ", ".join(self)
+class DateConverter(object):
+ def can_convert(self, obj):
+ return isinstance(obj, datetime.date)
+
+ def convert(self, obj, gateway_client):
+ Date = JavaClass("java.sql.Date", gateway_client)
+ return Date.valueOf(obj.strftime("%Y-%m-%d"))
+
+
+class DatetimeConverter(object):
+ def can_convert(self, obj):
+ return isinstance(obj, datetime.datetime)
+
+ def convert(self, obj, gateway_client):
+ Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
+ return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
+
+
+# datetime is a subclass of date, so DatetimeConverter must be registered first
+register_input_converter(DatetimeConverter())
+register_input_converter(DateConverter())
+
+
def _test():
import doctest
from pyspark.context import SparkContext
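For reference, the new DatetimeConverter hands java.sql.Timestamp the number of milliseconds since the epoch, computed from the local-time representation of the Python datetime. A minimal sketch of that arithmetic, runnable without a JVM or py4j (the helper name below is illustrative only):

import time
import datetime

def _to_timestamp_millis(dt):
    # Whole seconds in local time, scaled to milliseconds, plus the
    # millisecond part of the microseconds; sub-millisecond precision
    # is truncated, mirroring the converter above.
    return int(time.mktime(dt.timetuple())) * 1000 + dt.microsecond // 1000

print(_to_timestamp_millis(datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)))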
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index acf3c11454..f6f107ca32 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -25,7 +25,6 @@ else:
from itertools import imap as map
from py4j.protocol import Py4JError
-from py4j.java_collections import MapConverter
from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -442,15 +441,13 @@ class SQLContext(object):
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
if schema is None:
- df = self._ssql_ctx.load(source, joptions)
+ df = self._ssql_ctx.load(source, options)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ df = self._ssql_ctx.load(source, scala_datatype, options)
return DataFrame(df, self)
def createExternalTable(self, tableName, path=None, source=None,
@@ -471,16 +468,14 @@ class SQLContext(object):
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
if schema is None:
- df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ df = self._ssql_ctx.createExternalTable(tableName, source, options)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
- joptions)
+ options)
return DataFrame(df, self)
@ignore_unicode_prefix
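The dropped joptions plumbing in load() and createExternalTable() relies on the Py4J gateway auto-converting Python dicts to java.util.Map on the way into the JVM. A hedged usage sketch, assuming an active SQLContext bound to sqlCtx and a JSON file at the illustrative path below:

# Keyword options are gathered into a plain Python dict (**options) and passed
# straight through to the JVM; no MapConverter call is needed any more.
df = sqlCtx.load(path="/tmp/people.json", source="json")
df.printSchema()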
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 75c181c0c7..ca9bf8efb9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,8 +25,6 @@ if sys.version >= '3':
else:
from itertools import imap as map
-from py4j.java_collections import ListConverter, MapConverter
-
from pyspark.context import SparkContext
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -186,9 +184,7 @@ class DataFrame(object):
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
- joptions = MapConverter().convert(options,
- self.sql_ctx._sc._gateway._gateway_client)
- self._jdf.saveAsTable(tableName, source, jmode, joptions)
+ self._jdf.saveAsTable(tableName, source, jmode, options)
def save(self, path=None, source=None, mode="error", **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,9 +207,7 @@ class DataFrame(object):
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
- joptions = MapConverter().convert(options,
- self._sc._gateway._gateway_client)
- self._jdf.save(source, jmode, joptions)
+ self._jdf.save(source, jmode, options)
@property
def schema(self):
@@ -819,7 +813,6 @@ class DataFrame(object):
value = float(value)
if isinstance(value, dict):
- value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
elif subset is None:
return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
@@ -932,9 +925,7 @@ class GroupedData(object):
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
- jmap = MapConverter().convert(exprs[0],
- self.sql_ctx._sc._gateway._gateway_client)
- jdf = self._jdf.agg(jmap)
+ jdf = self._jdf.agg(exprs[0])
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
"""
if converter:
cols = [converter(c) for c in cols]
- jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
- return sc._jvm.PythonUtils.toSeq(jcols)
+ return sc._jvm.PythonUtils.toSeq(cols)
def _unary_op(name, doc="unary operator"):
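The same auto-conversion covers the dict arguments of fillna() and agg() and the Python lists flowing through _to_seq(). A hedged usage sketch, assuming a DataFrame df with "name" and "age" columns already exists:

# Both dict literals are handed to the JVM untouched and converted to
# java.util.Map by py4j, which is what made the converter plumbing removable.
df.fillna({"age": 50, "name": "unknown"}).show()
df.groupBy("name").agg({"age": "max"}).show()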
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aa3aa1d164..23e8428367 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@ import shutil
import tempfile
import pickle
import functools
+import datetime
import py4j
@@ -464,6 +465,16 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())
+ def test_filter_with_datetime(self):
+ time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
+ date = time.date()
+ row = Row(date=date, time=time)
+ df = self.sqlCtx.createDataFrame([row])
+ self.assertEqual(1, df.filter(df.date == date).count())
+ self.assertEqual(1, df.filter(df.time == time).count())
+ self.assertEqual(0, df.filter(df.date > date).count())
+ self.assertEqual(0, df.filter(df.time > time).count())
+
def test_dropna(self):
schema = StructType([
StructField("name", StringType(), True),