diff options
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/_types.py | 27 | ||||
-rw-r--r-- | python/pyspark/sql/context.py | 13 | ||||
-rw-r--r-- | python/pyspark/sql/dataframe.py | 18 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 11 |
4 files changed, 46 insertions, 23 deletions
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 110d1152fb..95fb91ad43 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -17,6 +17,7 @@ import sys import decimal +import time import datetime import keyword import warnings @@ -30,6 +31,9 @@ if sys.version >= "3": long = int unicode = str +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -1237,6 +1241,29 @@ class Row(tuple): return "<Row(%s)>" % ", ".join(self) +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) +register_input_converter(DateConverter()) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index acf3c11454..f6f107ca32 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -25,7 +25,6 @@ else: from itertools import imap as map from py4j.protocol import Py4JError -from py4j.java_collections import MapConverter from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer @@ -442,15 +441,13 @@ class SQLContext(object): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.load(source, joptions) + df = self._ssql_ctx.load(source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, joptions) + df = self._ssql_ctx.load(source, scala_datatype, options) return DataFrame(df, self) def createExternalTable(self, tableName, path=None, source=None, @@ -471,16 +468,14 @@ class SQLContext(object): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + df = self._ssql_ctx.createExternalTable(tableName, source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - joptions) + options) return DataFrame(df, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 75c181c0c7..ca9bf8efb9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,8 +25,6 @@ if sys.version >= '3': else: from itertools import imap as map -from py4j.java_collections import ListConverter, MapConverter - from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -186,9 +184,7 @@ class DataFrame(object): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self.sql_ctx._sc._gateway._gateway_client) - self._jdf.saveAsTable(tableName, source, jmode, joptions) + self._jdf.saveAsTable(tableName, source, jmode, options) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. @@ -211,9 +207,7 @@ class DataFrame(object): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) - self._jdf.save(source, jmode, joptions) + self._jdf.save(source, jmode, options) @property def schema(self): @@ -819,7 +813,6 @@ class DataFrame(object): value = float(value) if isinstance(value, dict): - value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) @@ -932,9 +925,7 @@ class GroupedData(object): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) + jdf = self._jdf.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" @@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None): """ if converter: cols = [converter(c) for c in cols] - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - return sc._jvm.PythonUtils.toSeq(jcols) + return sc._jvm.PythonUtils.toSeq(cols) def _unary_op(name, doc="unary operator"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aa3aa1d164..23e8428367 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import shutil import tempfile import pickle import functools +import datetime import py4j @@ -464,6 +465,16 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), |