Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r--  python/pyspark/sql/context.py  114
1 file changed, 72 insertions(+), 42 deletions(-)
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9d29ef4839..db4bcbece2 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -23,12 +23,18 @@ from itertools import imap
from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _verify_type, \
    _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
+try:
+    import pandas
+    has_pandas = True
+except ImportError:
+    has_pandas = False
+
__all__ = ["SQLContext", "HiveContext"]
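
The guarded import added above makes pandas a soft dependency: the module still imports cleanly when pandas is absent, and `createDataFrame` later branches on `has_pandas`. A minimal sketch of the same pattern in isolation (the `pandas_status` helper below is illustrative, not part of the patch):

    try:
        import pandas
        has_pandas = True
    except ImportError:
        has_pandas = False

    def pandas_status():
        # Branch on the flag instead of re-importing at every call site.
        if has_pandas:
            return "pandas %s available" % pandas.__version__
        return "pandas not installed; pandas.DataFrame inputs are unsupported"
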
@@ -116,6 +122,31 @@ class SQLContext(object):
                                      self._sc._javaAccumulator,
                                      returnType.json())
+    def _inferSchema(self, rdd, samplingRatio=None):
+        first = rdd.first()
+        if not first:
+            raise ValueError("The first row in RDD is empty, "
+                             "cannot infer schema")
+        if type(first) is dict:
+            warnings.warn("Using RDD of dict to inferSchema is deprecated, "
+                          "please use pyspark.sql.Row instead")
+
+        if samplingRatio is None:
+            schema = _infer_schema(first)
+            if _has_nulltype(schema):
+                for row in rdd.take(100)[1:]:
+                    schema = _merge_type(schema, _infer_schema(row))
+                    if not _has_nulltype(schema):
+                        break
+                else:
+                    raise ValueError("Some types cannot be determined by the "
+                                     "first 100 rows, please try again with sampling")
+        else:
+            if samplingRatio < 0.99:
+                rdd = rdd.sample(False, float(samplingRatio))
+            schema = rdd.map(_infer_schema).reduce(_merge_type)
+        return schema
+
    def inferSchema(self, rdd, samplingRatio=None):
        """Infer and apply a schema to an RDD of L{Row}.
@@ -171,29 +202,7 @@ class SQLContext(object):
        if isinstance(rdd, DataFrame):
            raise TypeError("Cannot apply schema to DataFrame")
-        first = rdd.first()
-        if not first:
-            raise ValueError("The first row in RDD is empty, "
-                             "can not infer schema")
-        if type(first) is dict:
-            warnings.warn("Using RDD of dict to inferSchema is deprecated,"
-                          "please use pyspark.sql.Row instead")
-
-        if samplingRatio is None:
-            schema = _infer_schema(first)
-            if _has_nulltype(schema):
-                for row in rdd.take(100)[1:]:
-                    schema = _merge_type(schema, _infer_schema(row))
-                    if not _has_nulltype(schema):
-                        break
-                else:
-                    warnings.warn("Some of types cannot be determined by the "
-                                  "first 100 rows, please try again with sampling")
-        else:
-            if samplingRatio < 0.99:
-                rdd = rdd.sample(False, float(samplingRatio))
-            schema = rdd.map(_infer_schema).reduce(_merge_type)
-
+        schema = self._inferSchema(rdd, samplingRatio)
        converter = _create_converter(schema)
        rdd = rdd.map(converter)
        return self.applySchema(rdd, schema)
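
With the helper factored out, `inferSchema` keeps only the conversion and the final `applySchema` call. Note the behavior change buried here: unresolved types after 100 rows now raise a ValueError instead of warning, so callers are pushed to retry with `samplingRatio`. A hedged sketch of that retry path (same `sc`/`sqlCtx` assumptions as above):

    from pyspark.sql import Row

    # Nulls dominate the head of this RDD, so plain inference would raise;
    # a ratio below 0.99 makes _inferSchema sample the whole RDD and
    # reduce the per-row schemas with _merge_type.
    rdd = sc.parallelize([Row(name='Alice', age=None)] * 200 +
                         [Row(name='Bob', age=5)] * 200)
    df = sqlCtx.inferSchema(rdd, samplingRatio=0.5)
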
@@ -274,7 +283,7 @@ class SQLContext(object):
raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
+ raise TypeError("schema should be StructType, but got %s" % schema)
# take the first few rows to verify schema
rows = rdd.take(10)
@@ -294,9 +303,9 @@ class SQLContext(object):
        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
        return DataFrame(df, self)
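
For comparison, `applySchema` remains the strict path: it takes an explicit `StructType`, verifies the first ten rows against it, and with this change names the offending value in its TypeError. A small usage sketch (hedged; the concrete field types are assumptions, not from the patch):

    from pyspark.sql.types import (StructType, StructField,
                                   StringType, IntegerType)

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])
    df = sqlCtx.applySchema(sc.parallelize([('Alice', 1)]), schema)

    # A non-StructType schema now fails with the clearer message, e.g.:
    # TypeError: schema should be StructType, but got ['name', 'age']
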
-    def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+    def createDataFrame(self, data, schema=None, samplingRatio=None):
        """
-        Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+        Create a DataFrame from an RDD of tuple/list, a list, or a pandas.DataFrame.
        `schema` could be :class:`StructType` or a list of column names.
@@ -311,12 +320,20 @@ class SQLContext(object):
        rows will be used for inference. The first row will be used if
        `samplingRatio` is None.
-        :param rdd: an RDD of Row or tuple or list or dict
+        :param data: an RDD of Row/tuple/list/dict, a list, or a pandas.DataFrame
        :param schema: a StructType or list of names of columns
        :param samplingRatio: the sample ratio of rows used for inferring
        :return: a DataFrame
-        >>> rdd = sc.parallelize([('Alice', 1)])
+        >>> l = [('Alice', 1)]
+        >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+        [Row(name=u'Alice', age=1)]
+
+        >>> d = [{'name': 'Alice', 'age': 1}]
+        >>> sqlCtx.createDataFrame(d).collect()
+        [Row(age=1, name=u'Alice')]
+
+        >>> rdd = sc.parallelize(l)
        >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
        >>> df.collect()
        [Row(name=u'Alice', age=1)]
@@ -336,19 +353,32 @@ class SQLContext(object):
        >>> df3.collect()
        [Row(name=u'Alice', age=1)]
        """
-        if isinstance(rdd, DataFrame):
-            raise TypeError("rdd is already a DataFrame")
+        if isinstance(data, DataFrame):
+            raise TypeError("data is already a DataFrame")
-        if isinstance(schema, StructType):
-            return self.applySchema(rdd, schema)
-        else:
-            if isinstance(schema, (list, tuple)):
-                first = rdd.first()
-                if not isinstance(first, (list, tuple)):
-                    raise ValueError("each row in `rdd` should be list or tuple")
-                row_cls = Row(*schema)
-                rdd = rdd.map(lambda r: row_cls(*r))
-            return self.inferSchema(rdd, samplingRatio)
+        if has_pandas and isinstance(data, pandas.DataFrame):
+            if schema is None:
+                schema = list(data.columns)
+            data = self._sc.parallelize(data.to_records(index=False))
+
+        if not isinstance(data, RDD):
+            try:
+                # data could be list, tuple, generator ...
+                data = self._sc.parallelize(data)
+            except Exception:
+                raise ValueError("cannot create an RDD from type: %s" % type(data))
+
+        if schema is None:
+            return self.inferSchema(data, samplingRatio)
+
+        if isinstance(schema, (list, tuple)):
+            first = data.first()
+            if not isinstance(first, (list, tuple)):
+                raise ValueError("each row in `data` should be list or tuple")
+            row_cls = Row(*schema)
+            schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+
+        return self.applySchema(data, schema)
    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.