From 1e42e96ece9e35ceed9ddebef66d589016878b56 Mon Sep 17 00:00:00 2001
From: Gabe Mulley
Date: Mon, 12 Jan 2015 21:44:51 -0800
Subject: [SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple

When attempting to infer the schema of an RDD that contains namedtuples,
PySpark fails to identify the records as namedtuples and raises an error.

Example:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from collections import namedtuple
import os

sc = SparkContext()
rdd = sc.textFile(os.path.join(os.getenv('SPARK_HOME'), 'README.md'))
TextLine = namedtuple('TextLine', 'line length')
tuple_rdd = rdd.map(lambda l: TextLine(line=l, length=len(l)))
tuple_rdd.take(5)  # This works

sqlc = SQLContext(sc)

# The following line raises an error
schema_rdd = sqlc.inferSchema(tuple_rdd)
```

The error raised is:

```
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 107, in main
    process()
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 98, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/serializers.py", line 227, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/rdd.py", line 1107, in takeUpToNumLeft
    yield next(iterator)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/sql.py", line 816, in convert_struct
    raise ValueError("unexpected tuple: %s" % obj)
TypeError: not all arguments converted during string formatting
```

(Two standalone sketches of the underlying Python behaviour follow the diff below.)

Author: Gabe Mulley

Closes #3978 from mulby/inferschema-namedtuple and squashes the following commits:

98c61cc [Gabe Mulley] Ensure exception message is populated correctly
375d96b [Gabe Mulley] Ensure schema can be inferred from a namedtuple
---
 python/pyspark/sql.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0e8b398fc6..014ac1791c 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -807,14 +807,14 @@ def _create_converter(dataType):
             return
 
         if isinstance(obj, tuple):
-            if hasattr(obj, "fields"):
-                d = dict(zip(obj.fields, obj))
-            if hasattr(obj, "__FIELDS__"):
+            if hasattr(obj, "_fields"):
+                d = dict(zip(obj._fields, obj))
+            elif hasattr(obj, "__FIELDS__"):
                 d = dict(zip(obj.__FIELDS__, obj))
             elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
                 d = dict(obj)
             else:
-                raise ValueError("unexpected tuple: %s" % obj)
+                raise ValueError("unexpected tuple: %s" % str(obj))
 
         elif isinstance(obj, dict):
             d = obj
@@ -1327,6 +1327,16 @@ class SQLContext(object):
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+        >>> from collections import namedtuple
+        >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+        >>> rdd = sc.parallelize(
+        ...     [CustomRow(field1=1, field2="row1"),
+        ...      CustomRow(field1=2, field2="row2"),
+        ...      CustomRow(field1=3, field2="row3")])
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()[0]
+        Row(field1=1, field2=u'row1')
         """
 
         if isinstance(rdd, SchemaRDD):
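
Why the original check never matched: namedtuple classes expose their field
names through the `_fields` attribute; there is no public `fields` attribute,
so the old `hasattr(obj, "fields")` test was always false and conversion fell
through to the error path. A minimal standalone sketch of the behaviour the
fix relies on (plain Python, no Spark required; the names are illustrative):

```python
from collections import namedtuple

TextLine = namedtuple('TextLine', 'line length')
record = TextLine(line='hello', length=5)

# The attribute the old code probed does not exist on namedtuples:
print(hasattr(record, 'fields'))   # False

# Field names actually live in `_fields`:
print(record._fields)              # ('line', 'length')

# The conversion the fixed code performs: pair field names with values.
print(dict(zip(record._fields, record)))  # {'line': 'hello', 'length': 5}
```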
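
Why the traceback above ends in a TypeError rather than the intended
ValueError (the subject of commit 98c61cc): applying `%` to a format string
with a tuple on the right-hand side unpacks the tuple as the argument list,
so building the error message itself fails. A minimal sketch, independent of
Spark:

```python
obj = (1, 'a', 2)

# The tuple is unpacked into three format arguments, but the string has
# only one %s placeholder, so formatting the message itself raises:
try:
    print("unexpected tuple: %s" % obj)
except TypeError as e:
    print(e)  # not all arguments converted during string formatting

# Converting the tuple to a string first produces the intended message:
print("unexpected tuple: %s" % str(obj))  # unexpected tuple: (1, 'a', 2)
```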