aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/readwriter.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/readwriter.py')
-rw-r--r--python/pyspark/sql/readwriter.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index a2771daabe..0b20022b14 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -130,11 +130,9 @@ class DataFrameReader(object):
self.schema(schema)
self.options(**options)
if path is not None:
- if type(path) == list:
- return self._df(
- self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
- else:
- return self._df(self._jreader.load(path))
+ if type(path) != list:
+ path = [path]
+ return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
else:
return self._df(self._jreader.load())
@@ -179,7 +177,17 @@ class DataFrameReader(object):
elif type(path) == list:
return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
elif isinstance(path, RDD):
- return self._df(self._jreader.json(path._jrdd))
+ def func(iterator):
+ for x in iterator:
+ if not isinstance(x, basestring):
+ x = unicode(x)
+ if isinstance(x, unicode):
+ x = x.encode("utf-8")
+ yield x
+ keyed = path.mapPartitions(func)
+ keyed._bypass_serializer = True
+ jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString())
+ return self._df(self._jreader.json(jrdd))
else:
raise TypeError("path can be only string or RDD")