author      Xiangrui Meng <meng@databricks.com>    2014-11-03 19:29:11 -0800
committer   Xiangrui Meng <meng@databricks.com>    2014-11-03 19:29:11 -0800
commit      04450d11548cfb25d4fb77d4a33e3a7cd4254183
tree        13b5c6fd1ac8c400cf59a51fe0b84b60c64d400f /python
parent      c5912ecc7b392a13089ae735c07c2d7256de36c6
[SPARK-4192][SQL] Internal API for Python UDT
Following #2919, this PR adds Python UDT (for internal use only) with tests under
"pyspark.tests". Before `SQLContext.applySchema`, we check whether we need to
convert user-type instances into SQL recognizable data. In the current
implementation, a Python UDT must be paired with a Scala UDT for serialization
on the JVM side. A following PR will add VectorUDT in MLlib for both Scala and
Python.

marmbrus jkbradley davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits:

acff637 [Xiangrui Meng] merge master
dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well
2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion
7c4a6a9 [Xiangrui Meng] address comments
75223db [Xiangrui Meng] minor update
f740379 [Xiangrui Meng] remove UDT from default imports
e98d9d0 [Xiangrui Meng] fix py style
4e84fce [Xiangrui Meng] remove local hive tests and add more tests
39f19e0 [Xiangrui Meng] add tests
b7f666d [Xiangrui Meng] add Python UDT
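For context, a minimal sketch of how the new internal API is exercised, mirroring the doctests and tests added by this patch (ExamplePoint, ExamplePointUDT, applySchema, and inferSchema all come from the diff below; the local SparkContext setup is an illustrative assumption, since the tests obtain their context from the test harness):

# Sketch only -- not part of the patch.
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row, StructType, StructField, DoubleType
from pyspark.tests import ExamplePoint, ExamplePointUDT

sc = SparkContext("local", "udt-sketch")
sqlCtx = SQLContext(sc)

# applySchema path: _verify_type checks each row against the schema, then
# _python_to_sql_converter serializes ExamplePoint into its sqlType ([x, y]).
schema = StructType([StructField("label", DoubleType(), False),
                     StructField("point", ExamplePointUDT(), False)])
srdd = sqlCtx.applySchema(sc.parallelize([(1.0, ExamplePoint(1.0, 2.0))]), schema)
point = srdd.first().point  # rebuilt on the way back via ExamplePointUDT.deserialize()

# inferSchema path: _infer_type returns obj.__UDT__ for UDT-tagged objects.
srdd2 = sqlCtx.inferSchema(sc.parallelize([Row(label=1.0, point=ExamplePoint(1.0, 2.0))]))
field = [f for f in srdd2.schema().fields if f.name == "point"][0]
assert type(field.dataType) == ExamplePointUDT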
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py    206
-rw-r--r--  python/pyspark/tests.py   93
2 files changed, 296 insertions, 3 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 675df084bf..d16c18bc79 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -417,6 +417,75 @@ class StructType(DataType):
return StructType([StructField.fromJson(f) for f in json["fields"]])
+class UserDefinedType(DataType):
+ """
+ :: WARN: Spark Internal Use Only ::
+ SQL User-Defined Type (UDT).
+ """
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sqlType(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sqlType().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def scalaUDT(cls):
+ """
+ The class name of the paired Scala UDT.
+ """
+ raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+ def serialize(self, obj):
+ """
+ Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
+ def json(self):
+ return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+ def jsonValue(self):
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ return schema
+
+ @classmethod
+ def fromJson(cls, json):
+ pyUDT = json["pyClass"]
+ split = pyUDT.rfind(".")
+ pyModule = pyUDT[:split]
+ pyClass = pyUDT[split+1:]
+ m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ UDT = getattr(m, pyClass)
+ return UDT()
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
@@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
+ >>> check_datatype(ExamplePointUDT())
+ True
+ >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> check_datatype(structtype_with_udt)
+ True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value):
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
- return _all_complex_types[json_value["type"]].fromJson(json_value)
+ tpe = json_value["type"]
+ if tpe in _all_complex_types:
+ return _all_complex_types[tpe].fromJson(json_value)
+ elif tpe == 'udt':
+ return UserDefinedType.fromJson(json_value)
+ else:
+ raise ValueError("not supported type: %s" % tpe)
# Mapping Python types to Spark SQL DataType
@@ -509,7 +590,18 @@ _type_mappings = {
def _infer_type(obj):
- """Infer the DataType from obj"""
+ """Infer the DataType from obj
+
+ >>> p = ExamplePoint(1.0, 2.0)
+ >>> _infer_type(p)
+ ExamplePointUDT
+ """
+ if obj is None:
+ raise ValueError("Can not infer type for None")
+
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
@@ -558,6 +650,93 @@ def _infer_schema(row):
return StructType(fields)
+def _need_python_to_sql_conversion(dataType):
+ """
+ Checks whether we need python to sql conversion for the given type.
+ For now, only UDTs need this conversion.
+
+ >>> _need_python_to_sql_conversion(DoubleType())
+ False
+ >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+ ... StructField("values", ArrayType(DoubleType(), False), False)])
+ >>> _need_python_to_sql_conversion(schema0)
+ False
+ >>> _need_python_to_sql_conversion(ExamplePointUDT())
+ True
+ >>> schema1 = ArrayType(ExamplePointUDT(), False)
+ >>> _need_python_to_sql_conversion(schema1)
+ True
+ >>> schema2 = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> _need_python_to_sql_conversion(schema2)
+ True
+ """
+ if isinstance(dataType, StructType):
+ return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ elif isinstance(dataType, ArrayType):
+ return _need_python_to_sql_conversion(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_python_to_sql_conversion(dataType.keyType) or \
+ _need_python_to_sql_conversion(dataType.valueType)
+ elif isinstance(dataType, UserDefinedType):
+ return True
+ else:
+ return False
+
+
+def _python_to_sql_converter(dataType):
+ """
+ Returns a converter that converts a Python object into a SQL datum for the given type.
+
+ >>> conv = _python_to_sql_converter(DoubleType())
+ >>> conv(1.0)
+ 1.0
+ >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+ >>> conv([1.0, 2.0])
+ [1.0, 2.0]
+ >>> conv = _python_to_sql_converter(ExamplePointUDT())
+ >>> conv(ExamplePoint(1.0, 2.0))
+ [1.0, 2.0]
+ >>> schema = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> conv = _python_to_sql_converter(schema)
+ >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+ (1.0, [1.0, 2.0])
+ """
+ if not _need_python_to_sql_conversion(dataType):
+ return lambda x: x
+
+ if isinstance(dataType, StructType):
+ names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+ converters = map(_python_to_sql_converter, types)
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
+ d = dict(obj)
+ return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return converter
+ elif isinstance(dataType, ArrayType):
+ element_converter = _python_to_sql_converter(dataType.elementType)
+ return lambda a: [element_converter(v) for v in a]
+ elif isinstance(dataType, MapType):
+ key_converter = _python_to_sql_converter(dataType.keyType)
+ value_converter = _python_to_sql_converter(dataType.valueType)
+ return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ elif isinstance(dataType, UserDefinedType):
+ return lambda obj: dataType.serialize(obj)
+ else:
+ raise ValueError("Unexpected type %r" % dataType)
+
+
def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
@@ -818,11 +997,22 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
+ >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
# all objects are nullable
if obj is None:
return
+ if isinstance(dataType, UserDefinedType):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+ _verify_type(dataType.serialize(obj), dataType.sqlType())
+ return
+
_type = type(dataType)
assert _type in _acceptable_types, "unknown datatype: %s" % dataType
@@ -897,6 +1087,8 @@ def _has_struct_or_date(dt):
return _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
+ elif isinstance(dt, UserDefinedType):
+ return True
return False
@@ -967,6 +1159,9 @@ def _create_cls(dataType):
elif isinstance(dataType, DateType):
return datetime.date
+ elif isinstance(dataType, UserDefinedType):
+ return lambda datum: dataType.deserialize(datum)
+
elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)
@@ -1244,6 +1439,10 @@ class SQLContext(object):
for row in rows:
_verify_type(row, schema)
+ # convert python objects to sql data
+ converter = _python_to_sql_converter(schema)
+ rdd = rdd.map(converter)
+
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
@@ -1877,6 +2076,7 @@ def _test():
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
@@ -1888,6 +2088,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 68fd756876..e947b09468 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,8 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
from pyspark import shuffle
_have_scipy = False
@@ -694,8 +695,65 @@ class ProfilerTests(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
class SQLTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
+
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
@@ -824,6 +882,39 @@ class SQLTests(ReusedPySparkTestCase):
row = self.sqlCtx.sql("select l[0].a AS la from test").first()
self.assertEqual(1, row.asDict()["la"])
+ def test_infer_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd = self.sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ srdd.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ srdd = self.sqlCtx.applySchema(rdd, schema)
+ point = srdd.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ srdd0.saveAsParquetFile(output_dir)
+ srdd1 = self.sqlCtx.parquetFile(output_dir)
+ point = srdd1.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
class InputFormatTests(ReusedPySparkTestCase):