From ab9128fb7ec7ca479dc91e7cc7c775e8d071eafa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Apr 2015 00:08:18 -0700 Subject: [SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression This PR enable auto_convert in JavaGateway, then we could register a converter for a given types, for example, date and datetime. There are two bugs related to auto_convert, see [1] and [2], we workaround it in this PR. [1] https://github.com/bartdag/py4j/issues/160 [2] https://github.com/bartdag/py4j/issues/161 cc rxin JoshRosen Author: Davies Liu Closes #5570 from davies/py4j_date and squashes the following commits: eb4fa53 [Davies Liu] fix tests in python 3 d17d634 [Davies Liu] rollback changes in mllib 2e7566d [Davies Liu] convert tuple into ArrayList ceb3779 [Davies Liu] Update rdd.py 3c373f3 [Davies Liu] support date and datetime by auto_convert cb094ff [Davies Liu] enable auto convert --- python/pyspark/context.py | 6 +----- python/pyspark/java_gateway.py | 15 ++++++++++++++- python/pyspark/rdd.py | 3 +++ python/pyspark/sql/_types.py | 27 +++++++++++++++++++++++++++ python/pyspark/sql/context.py | 13 ++++--------- python/pyspark/sql/dataframe.py | 18 ++++-------------- python/pyspark/sql/tests.py | 11 +++++++++++ python/pyspark/streaming/context.py | 11 +++-------- python/pyspark/streaming/kafka.py | 7 ++----- python/pyspark/streaming/tests.py | 6 +----- 10 files changed, 70 insertions(+), 47 deletions(-) (limited to 'python') diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6a743ac8bd..b006120eb2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -23,8 +23,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -from py4j.java_collections import ListConverter - from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -643,7 +641,6 @@ class SparkContext(object): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -846,13 +843,12 @@ class SparkContext(object): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, allowLocal) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 45bc38f7e6..3cee4ea6e3 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,17 +17,30 @@ import atexit import os +import sys import select import signal import shlex import socket import platform from subprocess import Popen, PIPE + +if sys.version >= '3': + xrange = range + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter from pyspark.serializers import read_int +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, (list, tuple, xrange)) + +ListConverter.can_convert = can_convert_list + + def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) @@ -92,7 +105,7 @@ def launch_gateway(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d9cdbb666f..d254deb527 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) + # There is a bug in py4j.java_gateway.JavaClass with auto_convert + # https://github.com/bartdag/py4j/issues/161 + # TODO: use auto_convert once py4j fix the bug broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 110d1152fb..95fb91ad43 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -17,6 +17,7 @@ import sys import decimal +import time import datetime import keyword import warnings @@ -30,6 +31,9 @@ if sys.version >= "3": long = int unicode = str +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -1237,6 +1241,29 @@ class Row(tuple): return "" % ", ".join(self) +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) +register_input_converter(DateConverter()) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index acf3c11454..f6f107ca32 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -25,7 +25,6 @@ else: from itertools import imap as map from py4j.protocol import Py4JError -from py4j.java_collections import MapConverter from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer @@ -442,15 +441,13 @@ class SQLContext(object): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.load(source, joptions) + df = self._ssql_ctx.load(source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, joptions) + df = self._ssql_ctx.load(source, scala_datatype, options) return DataFrame(df, self) def createExternalTable(self, tableName, path=None, source=None, @@ -471,16 +468,14 @@ class SQLContext(object): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + df = self._ssql_ctx.createExternalTable(tableName, source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - joptions) + options) return DataFrame(df, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 75c181c0c7..ca9bf8efb9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,8 +25,6 @@ if sys.version >= '3': else: from itertools import imap as map -from py4j.java_collections import ListConverter, MapConverter - from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -186,9 +184,7 @@ class DataFrame(object): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self.sql_ctx._sc._gateway._gateway_client) - self._jdf.saveAsTable(tableName, source, jmode, joptions) + self._jdf.saveAsTable(tableName, source, jmode, options) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. @@ -211,9 +207,7 @@ class DataFrame(object): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) - self._jdf.save(source, jmode, joptions) + self._jdf.save(source, jmode, options) @property def schema(self): @@ -819,7 +813,6 @@ class DataFrame(object): value = float(value) if isinstance(value, dict): - value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) @@ -932,9 +925,7 @@ class GroupedData(object): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) + jdf = self._jdf.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" @@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None): """ if converter: cols = [converter(c) for c in cols] - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - return sc._jvm.PythonUtils.toSeq(jcols) + return sc._jvm.PythonUtils.toSeq(cols) def _unary_op(name, doc="unary operator"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aa3aa1d164..23e8428367 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import shutil import tempfile import pickle import functools +import datetime import py4j @@ -464,6 +465,16 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4590c58839..ac5ba69e8d 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,7 +20,6 @@ from __future__ import print_function import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf @@ -305,9 +304,7 @@ class StreamingContext(object): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -322,8 +319,7 @@ class StreamingContext(object): the transform function parameter will be the same as the order of corresponding DStreams in the list. """ - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -346,6 +342,5 @@ class StreamingContext(object): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 7a7b6e1d9a..8d610d6569 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,8 +15,7 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError, Py4JJavaError +from py4j.java_gateway import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer @@ -57,8 +56,6 @@ class KafkaUtils(object): }) if not isinstance(topics, dict): raise TypeError("topics should be dict") - jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: @@ -66,7 +63,7 @@ class KafkaUtils(object): helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 06d2215437..33f958a601 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,8 +24,6 @@ import tempfile import struct from functools import reduce -from py4j.java_collections import MapConverter - from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -581,11 +579,9 @@ class KafkaStreamTests(PySparkStreamingTestCase): """Test the Python Kafka stream API.""" topic = "topic1" sendData = {"a": 3, "b": 5, "c": 10} - jSendData = MapConverter().convert(sendData, - self.ssc.sparkContext._gateway._gateway_client) self._kafkaTestUtils.createTopic(topic) - self._kafkaTestUtils.sendMessages(topic, jSendData) + self._kafkaTestUtils.sendMessages(topic, sendData) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, -- cgit v1.2.3