 pylintrc | 2
 python/pyspark/cloudpickle.py | 38
 python/pyspark/shuffle.py | 2
 python/pyspark/sql/context.py | 108
 python/pyspark/sql/tests.py | 112
 python/pyspark/sql/types.py | 78
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala | 9
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 29
 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 1
 9 files changed, 286 insertions(+), 93 deletions(-)
diff --git a/pylintrc b/pylintrc
index 0617759603..6a675770da 100644
--- a/pylintrc
+++ b/pylintrc
@@ -84,7 +84,7 @@ enable=
# If you would like to improve the code quality of pyspark, remove any of these disabled errors
# run ./dev/lint-python and see if the errors raised by pylint can be fixed.
-disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable
[REPORTS]
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 9ef93071d2..3b64798580 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -350,7 +350,26 @@ class CloudPickler(Pickler):
if new_override:
d['__new__'] = obj.__new__
- self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
+ self.save(_load_class)
+ self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj)
+ d.pop('__doc__', None)
+ # handle property and staticmethod
+ dd = {}
+ for k, v in d.items():
+ if isinstance(v, property):
+ k = ('property', k)
+ v = (v.fget, v.fset, v.fdel, v.__doc__)
+ elif isinstance(v, staticmethod) and hasattr(v, '__func__'):
+ k = ('staticmethod', k)
+ v = v.__func__
+ elif isinstance(v, classmethod) and hasattr(v, '__func__'):
+ k = ('classmethod', k)
+ v = v.__func__
+ dd[k] = v
+ self.save(dd)
+ self.write(pickle.TUPLE2)
+ self.write(pickle.REDUCE)
+
else:
raise pickle.PicklingError("Can't pickle %r" % obj)
@@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None):
None, None, closure)
+def _load_class(cls, d):
+ """
+ Loads additional properties into class `cls`.
+ """
+ for k, v in d.items():
+ if isinstance(k, tuple):
+ typ, k = k
+ if typ == 'property':
+ v = property(*v)
+ elif typ == 'staticmethod':
+ v = staticmethod(v)
+ elif typ == 'classmethod':
+ v = classmethod(v)
+ setattr(cls, k, v)
+ return cls
+
+
"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""
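
The hunk above pickles a class's __doc__ first and replays the remaining attributes through _load_class, so property, staticmethod and classmethod descriptors survive the round trip. A minimal sketch of the effect (class and method names are illustrative, not from the patch), using the CloudPickleSerializer wrapper that the types.py changes below also rely on; it assumes the class is defined in __main__ so cloudpickle serializes the class body itself:

    from pyspark.serializers import CloudPickleSerializer

    class Point2D(object):                     # hypothetical class defined in __main__
        def __init__(self, x, y):
            self._x, self._y = x, y

        @property
        def x(self):                           # saved under the key ('property', 'x')
            return self._x

        @staticmethod
        def origin():                          # saved under ('staticmethod', 'origin')
            return Point2D(0.0, 0.0)

        @classmethod
        def from_tuple(cls, t):                # saved under ('classmethod', 'from_tuple')
            return cls(t[0], t[1])

    ser = CloudPickleSerializer()
    Restored = ser.loads(ser.dumps(Point2D))   # _load_class re-attaches the three descriptors
    assert Restored(1.0, 2.0).x == 1.0
    assert Restored.origin()._y == 0.0
    assert Restored.from_tuple((3.0, 4.0))._x == 3.0
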
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 8fb71bac64..b8118bdb7c 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -606,7 +606,7 @@ class ExternalList(object):
if not os.path.exists(d):
os.makedirs(d)
p = os.path.join(d, str(id(self)))
- self._file = open(p, "wb+", 65536)
+ self._file = open(p, "w+b", 65536)
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
os.unlink(p)
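
For reference, both spellings "wb+" and "w+b" request the same thing: a truncated binary file open for reading and writing, here with a 64 KB buffer. A small illustrative check (the path is a placeholder, not the spill location ExternalList uses):

    import os, tempfile

    path = os.path.join(tempfile.mkdtemp(), "spill")   # placeholder spill file
    f = open(path, "w+b", 65536)                       # truncate, binary, read+write, 64 KB buffer
    f.write(b"batch-0")
    f.seek(0)
    assert f.read() == b"batch-0"
    f.close()
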
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index abb6522dde..917de24f35 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -277,6 +277,66 @@ class SQLContext(object):
return self.createDataFrame(rdd, schema)
+ def _createFromRDD(self, rdd, schema, samplingRatio):
+ """
+ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
+ """
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchema(rdd, samplingRatio)
+ converter = _create_converter(struct)
+ rdd = rdd.map(converter)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif isinstance(schema, StructType):
+ # take the first few rows to verify schema
+ rows = rdd.take(10)
+ for row in rows:
+ _verify_type(row, schema)
+
+ else:
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ rdd = rdd.map(schema.toInternal)
+ return rdd, schema
+
+ def _createFromLocal(self, data, schema):
+ """
+ Create an RDD for DataFrame from a list or pandas.DataFrame, returns
+ the RDD and schema.
+ """
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = [str(x) for x in data.columns]
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ # make sure data can be consumed multiple times
+ if not isinstance(data, list):
+ data = list(data)
+
+ if schema is None or isinstance(schema, (list, tuple)):
+ struct = self._inferSchemaFromList(data)
+ if isinstance(schema, (list, tuple)):
+ for i, name in enumerate(schema):
+ struct.fields[i].name = name
+ struct.names[i] = name
+ schema = struct
+
+ elif isinstance(schema, StructType):
+ for row in data:
+ _verify_type(row, schema)
+
+ else:
+ raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
+
+ # convert python objects to sql data
+ data = [schema.toInternal(row) for row in data]
+ return self._sc.parallelize(data), schema
+
@since(1.3)
@ignore_unicode_prefix
def createDataFrame(self, data, schema=None, samplingRatio=None):
@@ -340,49 +400,15 @@ class SQLContext(object):
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
- if has_pandas and isinstance(data, pandas.DataFrame):
- if schema is None:
- schema = [str(x) for x in data.columns]
- data = [r.tolist() for r in data.to_records(index=False)]
-
- if not isinstance(data, RDD):
- if not isinstance(data, list):
- data = list(data)
- try:
- # data could be list, tuple, generator ...
- rdd = self._sc.parallelize(data)
- except Exception:
- raise TypeError("cannot create an RDD from type: %s" % type(data))
+ if isinstance(data, RDD):
+ rdd, schema = self._createFromRDD(data, schema, samplingRatio)
else:
- rdd = data
-
- if schema is None or isinstance(schema, (list, tuple)):
- if isinstance(data, RDD):
- struct = self._inferSchema(rdd, samplingRatio)
- else:
- struct = self._inferSchemaFromList(data)
- if isinstance(schema, (list, tuple)):
- for i, name in enumerate(schema):
- struct.fields[i].name = name
- schema = struct
- converter = _create_converter(schema)
- rdd = rdd.map(converter)
-
- elif isinstance(schema, StructType):
- # take the first few rows to verify schema
- rows = rdd.take(10)
- for row in rows:
- _verify_type(row, schema)
-
- else:
- raise TypeError("schema should be StructType or list or None")
-
- # convert python objects to sql data
- rdd = rdd.map(schema.toInternal)
-
+ rdd, schema = self._createFromLocal(data, schema)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
- df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return DataFrame(df, self)
+ jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ df = DataFrame(jdf, self)
+ df._schema = schema
+ return df
@since(1.3)
def registerDataFrameAsTable(self, df, tableName):
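
Taken together, createDataFrame now only dispatches: local data goes through _createFromLocal, RDDs through _createFromRDD, and both paths hand back converted rows plus a StructType. A hedged usage sketch of the two paths (app name and column names are illustrative):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext

    sc = SparkContext("local[2]", "createDataFrame-demo")   # arbitrary app name
    sqlContext = SQLContext(sc)
    rows = [(1, "a"), (2, "b")]

    # local list or pandas data -> _createFromLocal: infer/apply the schema, then parallelize
    df1 = sqlContext.createDataFrame(rows, ["id", "name"])

    # existing RDD -> _createFromRDD: sample rows to infer or verify the schema, then map toInternal
    df2 = sqlContext.createDataFrame(sc.parallelize(rows), ["id", "name"])

    assert df1.schema == df2.schema
    sc.stop()
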
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5aa6135dc1..ebd3ea8db6 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -75,7 +75,7 @@ class ExamplePointUDT(UserDefinedType):
@classmethod
def module(cls):
- return 'pyspark.tests'
+ return 'pyspark.sql.tests'
@classmethod
def scalaUDT(cls):
@@ -106,10 +106,45 @@ class ExamplePoint:
return "(%s,%s)" % (self.x, self.y)
def __eq__(self, other):
- return isinstance(other, ExamplePoint) and \
+ return isinstance(other, self.__class__) and \
other.x == self.x and other.y == self.y
+class PythonOnlyUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for PythonOnlyPoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return '__main__'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return PythonOnlyPoint(datum[0], datum[1])
+
+ @staticmethod
+ def foo():
+ pass
+
+ @property
+ def props(self):
+ return {}
+
+
+class PythonOnlyPoint(ExamplePoint):
+ """
+ An example class to demonstrate UDT in Python only
+ """
+ __UDT__ = PythonOnlyUDT()
+
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
def test_data_type_eq(self):
@@ -395,10 +430,39 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEqual(1, row.asDict()["l"][0].a)
self.assertEqual(1.0, row.asDict()['d']['key'].c)
+ def test_udt(self):
+ from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type
+ from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+
+ def check_datatype(datatype):
+ pickled = pickle.loads(pickle.dumps(datatype))
+ assert datatype == pickled
+ scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ python_datatype = _parse_datatype_json_string(scala_datatype.json())
+ assert datatype == python_datatype
+
+ check_datatype(ExamplePointUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = ExamplePoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), ExamplePointUDT())
+ _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT()))
+
+ check_datatype(PythonOnlyUDT())
+ structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ check_datatype(structtype_with_udt)
+ p = PythonOnlyPoint(1.0, 2.0)
+ self.assertEqual(_infer_type(p), PythonOnlyUDT())
+ _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
+ self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sc.parallelize([row]).toDF()
+ df = self.sqlCtx.createDataFrame([row])
schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
@@ -406,36 +470,66 @@ class SQLTests(ReusedPySparkTestCase):
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
self.assertEqual(point, ExamplePoint(1.0, 2.0))
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.sqlCtx.createDataFrame([row])
+ schema = df.schema
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), PythonOnlyUDT)
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
+
def test_apply_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = (1.0, ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = rdd.toDF(schema)
+ df = self.sqlCtx.createDataFrame([row], schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ row = (1.0, PythonOnlyPoint(1.0, 2.0))
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", PythonOnlyUDT(), False)])
+ df = self.sqlCtx.createDataFrame([row], schema)
+ point = df.head().point
+ self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
def test_udf_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df = self.sc.parallelize([row]).toDF()
+ df = self.sqlCtx.createDataFrame([row])
self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.sqlCtx.createDataFrame([row])
+ self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
+ udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+ self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+
def test_parquet_with_udt(self):
- from pyspark.sql.tests import ExamplePoint
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- df0 = self.sc.parallelize([row]).toDF()
+ df0 = self.sqlCtx.createDataFrame([row])
output_dir = os.path.join(self.tempdir.name, "labeled_point")
- df0.saveAsParquetFile(output_dir)
+ df0.write.parquet(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
point = df1.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df0 = self.sqlCtx.createDataFrame([row])
+ df0.write.parquet(output_dir, mode='overwrite')
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
+ self.assertEquals(point, PythonOnlyPoint(1.0, 2.0))
+
def test_column_operators(self):
ci = self.df.key
cs = self.df.value
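
The new PythonOnlyUDT/PythonOnlyPoint pair exercises a UDT with no Scala counterpart. A minimal sketch of the pattern a user would follow (all names below are hypothetical, mirroring the test classes above):

    from pyspark.sql.types import UserDefinedType, ArrayType, DoubleType

    class Point3DUDT(UserDefinedType):
        """Hypothetical UDT; scalaUDT() is not overridden, so the '' default marks it Python-only."""

        @classmethod
        def sqlType(cls):
            return ArrayType(DoubleType(), False)

        @classmethod
        def module(cls):
            return '__main__'            # allowed now: the class itself is cloudpickled into the schema

        def serialize(self, obj):
            return [obj.x, obj.y, obj.z]

        def deserialize(self, datum):
            return Point3D(datum[0], datum[1], datum[2])

    class Point3D(object):
        __UDT__ = Point3DUDT()           # attaching the UDT is what _infer_type looks for

        def __init__(self, x, y, z):
            self.x, self.y, self.z = x, y, z

A DataFrame built from such objects round-trips them the same way the tests above exercise PythonOnlyPoint.
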
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8859308d66..0976aea72c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -22,6 +22,7 @@ import datetime
import calendar
import json
import re
+import base64
from array import array
if sys.version >= "3":
@@ -31,6 +32,8 @@ if sys.version >= "3":
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
+from pyspark.serializers import CloudPickleSerializer
+
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -458,7 +461,7 @@ class StructType(DataType):
self.names = [f.name for f in fields]
assert all(isinstance(f, StructField) for f in fields),\
"fields should be a list of StructField"
- self._needSerializeFields = None
+ self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
def add(self, field, data_type=None, nullable=True, metadata=None):
"""
@@ -501,6 +504,7 @@ class StructType(DataType):
data_type_f = data_type
self.fields.append(StructField(field, data_type_f, nullable, metadata))
self.names.append(field)
+ self._needSerializeAnyField = any(f.needConversion() for f in self.fields)
return self
def simpleString(self):
@@ -526,10 +530,7 @@ class StructType(DataType):
if obj is None:
return
- if self._needSerializeFields is None:
- self._needSerializeFields = any(f.needConversion() for f in self.fields)
-
- if self._needSerializeFields:
+ if self._needSerializeAnyField:
if isinstance(obj, dict):
return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
elif isinstance(obj, (tuple, list)):
@@ -550,7 +551,10 @@ class StructType(DataType):
if isinstance(obj, Row):
# it's already converted by pickler
return obj
- values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
+ if self._needSerializeAnyField:
+ values = [f.fromInternal(v) for f, v in zip(self.fields, obj)]
+ else:
+ values = obj
return _create_row(self.names, values)
@@ -581,9 +585,10 @@ class UserDefinedType(DataType):
@classmethod
def scalaUDT(cls):
"""
- The class name of the paired Scala UDT.
+ The class name of the paired Scala UDT (could be '', if there
+ is no corresponding one).
"""
- raise NotImplementedError("UDT must have a paired Scala UDT.")
+ return ''
def needConversion(self):
return True
@@ -622,22 +627,37 @@ class UserDefinedType(DataType):
return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
def jsonValue(self):
- schema = {
- "type": "udt",
- "class": self.scalaUDT(),
- "pyClass": "%s.%s" % (self.module(), type(self).__name__),
- "sqlType": self.sqlType().jsonValue()
- }
+ if self.scalaUDT():
+ assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT'
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ else:
+ ser = CloudPickleSerializer()
+ b = ser.dumps(type(self))
+ schema = {
+ "type": "udt",
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "serializedClass": base64.b64encode(b).decode('utf8'),
+ "sqlType": self.sqlType().jsonValue()
+ }
return schema
@classmethod
def fromJson(cls, json):
- pyUDT = json["pyClass"]
+ pyUDT = str(json["pyClass"])
split = pyUDT.rfind(".")
pyModule = pyUDT[:split]
pyClass = pyUDT[split+1:]
m = __import__(pyModule, globals(), locals(), [pyClass])
- UDT = getattr(m, pyClass)
+ if not hasattr(m, pyClass):
+ s = base64.b64decode(json['serializedClass'].encode('utf-8'))
+ UDT = CloudPickleSerializer().loads(s)
+ else:
+ UDT = getattr(m, pyClass)
return UDT()
def __eq__(self, other):
@@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string):
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
-
- >>> check_datatype(ExamplePointUDT())
- >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> check_datatype(structtype_with_udt)
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -752,10 +767,6 @@ if sys.version < "3":
def _infer_type(obj):
"""Infer the DataType from obj
-
- >>> p = ExamplePoint(1.0, 2.0)
- >>> _infer_type(p)
- ExamplePointUDT
"""
if obj is None:
return NullType()
@@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
- >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
- >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
"""
# all objects are nullable
if obj is None:
@@ -1259,18 +1265,12 @@ register_input_converter(DateConverter())
def _test():
import doctest
from pyspark.context import SparkContext
- # let doctest run in pyspark.sql.types, so DataTypes can be picklable
- import pyspark.sql.types
- from pyspark.sql import Row, SQLContext
- from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
- globs = pyspark.sql.types.__dict__.copy()
+ from pyspark.sql import SQLContext
+ globs = globals()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
- globs['ExamplePoint'] = ExamplePoint
- globs['ExamplePointUDT'] = ExamplePointUDT
- (failure_count, test_count) = doctest.testmod(
- pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
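
A sketch of what the new jsonValue/fromJson pair produces for a UDT without a Scala counterpart: the schema JSON carries the cloudpickled class as base64 under "serializedClass" instead of a JVM "class" name. The UDT below is hypothetical and assumes it runs in __main__:

    import base64, json
    from pyspark.serializers import CloudPickleSerializer
    from pyspark.sql.types import UserDefinedType, DoubleType

    class ScoreUDT(UserDefinedType):               # hypothetical Python-only UDT
        @classmethod
        def sqlType(cls):
            return DoubleType()

        @classmethod
        def module(cls):
            return '__main__'

        def serialize(self, obj):
            return float(obj)

        def deserialize(self, datum):
            return datum

    schema = json.loads(ScoreUDT().json())
    assert schema["type"] == "udt"
    assert "class" not in schema                   # no paired Scala UDT class name
    blob = base64.b64decode(schema["serializedClass"])
    cls = CloudPickleSerializer().loads(blob)      # the fallback fromJson() takes for __main__ classes
    assert issubclass(cls, UserDefinedType)
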
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 591fb26e67..f4428c2e8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -142,12 +142,21 @@ object DataType {
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))
+ // Scala/Java UDT
case JSortedObject(
("class", JString(udtClass)),
("pyClass", _),
("sqlType", _),
("type", JString("udt"))) =>
Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+
+ // Python UDT
+ case JSortedObject(
+ ("pyClass", JString(pyClass)),
+ ("serializedClass", JString(serialized)),
+ ("sqlType", v: JValue),
+ ("type", JString("udt"))) =>
+ new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
}
private def parseStructField(json: JValue): StructField = json match {
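
For reference, the two JSON shapes the matcher distinguishes, written as Python literals with placeholder values (the class names are made up):

    # Scala/Java UDT: carries the JVM class to instantiate via reflection
    scala_udt_json = {
        "type": "udt",
        "class": "org.example.MyUDT",
        "pyClass": "mypkg.MyUDT",
        "sqlType": "double",
    }

    # Python-only UDT: no "class" key, so it becomes a PythonUserDefinedType
    python_udt_json = {
        "type": "udt",
        "pyClass": "__main__.MyUDT",
        "serializedClass": "<base64 cloudpickle payload>",
        "sqlType": "double",
    }
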
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index e47cfb4833..4305903616 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
/** Paired Python UDT class, if exists. */
def pyUDT: String = null
+ /** Serialized Python UDT class, if exists. */
+ def serializedPyClass: String = null
+
/**
* Convert the user type to a SQL datum
*
@@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass
}
+
+/**
+ * ::DeveloperApi::
+ * The user defined type in Python.
+ *
+ * Note: This can only be accessed via Python UDF, or accessed as a serialized object.
+ */
+private[sql] class PythonUserDefinedType(
+ val sqlType: DataType,
+ override val pyUDT: String,
+ override val serializedPyClass: String) extends UserDefinedType[Any] {
+
+ /* The serialization is handled by UDT class in Python */
+ override def serialize(obj: Any): Any = obj
+ override def deserialize(datam: Any): Any = datam
+
+ /* There is no Java class for Python UDT */
+ override def userClass: java.lang.Class[Any] = null
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("pyClass" -> pyUDT) ~
+ ("serializedClass" -> serializedPyClass) ~
+ ("sqlType" -> sqlType.jsonValue)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ec084a2996..3c38916fd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -267,7 +267,6 @@ object EvaluatePython {
pickler.save(row.values(i))
i += 1
}
- row.values.foreach(pickler.save)
out.write(Opcodes.TUPLE)
out.write(Opcodes.REDUCE)
}