author     Davies Liu <davies@databricks.com>  2015-02-10 19:40:12 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-10 19:40:12 -0800
commit     ea60284095cad43aa7ac98256576375d0e91a52a (patch)
tree       35ac6e3935e1e7c731f7b9a850f2daa9640387d1  /python/pyspark/sql/context.py
parent     a60aea86b4d4b716b5ec3bff776b509fe0831342 (diff)
[SPARK-5704] [SQL] [PySpark] createDataFrame from RDD with columns
Deprecate inferSchema() and applySchema(); use createDataFrame() instead, which can take an optional `schema` to create a DataFrame from an RDD. The `schema` can be a StructType or a list of column names.

Author: Davies Liu <davies@databricks.com>

Closes #4498 from davies/create and squashes the following commits:

08469c1 [Davies Liu] remove Scala/Java API for now
c80a7a9 [Davies Liu] fix hive test
d1bd8f2 [Davies Liu] cleanup applySchema
9526e97 [Davies Liu] createDataFrame from RDD with columns
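A minimal migration sketch (not part of the patch; it assumes a live SparkContext `sc` and SQLContext `sqlCtx`, as in the doctests below):

    from pyspark.sql import Row
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    rdd = sc.parallelize([("Alice", 1), ("Bob", 2)])

    # Before (deprecated in 1.3): wrap each row yourself, then infer the schema.
    people = rdd.map(lambda r: Row(name=r[0], age=r[1]))
    df_old = sqlCtx.inferSchema(people)

    # After: pass a list of column names and let the types be inferred ...
    df_new = sqlCtx.createDataFrame(rdd, ["name", "age"])

    # ... or pass an explicit StructType, replacing applySchema().
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True)])
    df_typed = sqlCtx.createDataFrame(rdd, schema)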
Diffstat (limited to 'python/pyspark/sql/context.py')
-rw-r--r--  python/pyspark/sql/context.py | 87
1 file changed, 67 insertions(+), 20 deletions(-)
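One behavioral fix rides along in the inferSchema() hunk below: the sampling guard was inverted, so the RDD is now downsampled when `samplingRatio` is below 0.99 rather than above it. A sketch of the corrected branch (the helper name here is illustrative, not from the patch):

    def _sampled_rows(rdd, samplingRatio):
        # With the fix: downsample only when the requested ratio is
        # meaningfully below 1.0; at 0.99 and above, use the whole RDD.
        if samplingRatio < 0.99:
            rdd = rdd.sample(False, float(samplingRatio))
        return rdd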
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 882c0f98ea..9d29ef4839 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -25,7 +25,7 @@ from py4j.java_collections import MapConverter
from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
@@ -47,23 +47,11 @@ class SQLContext(object):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
-
- >>> bad_rdd = sc.parallelize([1,2,3])
- >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
-
>>> from datetime import datetime
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df = sqlCtx.createDataFrame(allTypes)
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -131,6 +119,9 @@ class SQLContext(object):
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
+ .. note::
+     Deprecated in 1.3, use :func:`createDataFrame` instead.
+
When samplingRatio is specified, the schema is inferred by looking
at the types of each row in the sampled dataset. Otherwise, the
first 100 rows of the RDD are inspected. Nested collections are
@@ -199,7 +190,7 @@ class SQLContext(object):
warnings.warn("Some of types cannot be determined by the "
"first 100 rows, please try again with sampling")
else:
- if samplingRatio > 0.99:
+ if samplingRatio < 0.99:
rdd = rdd.sample(False, float(samplingRatio))
schema = rdd.map(_infer_schema).reduce(_merge_type)
@@ -211,6 +202,9 @@ class SQLContext(object):
"""
Applies the given schema to the given RDD of L{tuple} or L{list}.
+ .. note::
+     Deprecated in 1.3, use :func:`createDataFrame` instead.
+
These tuples or lists can contain complex nested structures like
lists, maps or nested rows.
@@ -300,13 +294,68 @@ class SQLContext(object):
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
+ def createDataFrame(self, rdd, schema=None, samplingRatio=None):
+ """
+ Create a DataFrame from an RDD of tuple/list and an optional `schema`.
+
+ `schema` could be :class:`StructType` or a list of column names.
+
+ When `schema` is a list of column names, the type of each column
+ will be inferred from `rdd`.
+
+ When `schema` is None, it will try to infer the column names and types
+ from `rdd`, which should be an RDD of :class:`Row`, or namedtuple,
+ or dict.
+
+ If schema inference is needed, `samplingRatio` is used to determine
+ how many rows will be used to do the inference. The first row will
+ be used if `samplingRatio` is None.
+
+ :param rdd: an RDD of Row or tuple or list or dict
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd = sc.parallelize([('Alice', 1)])
+ >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql import Row
+ >>> Person = Row('name', 'age')
+ >>> person = rdd.map(lambda r: Person(*r))
+ >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("name", StringType(), True),
+ ... StructField("age", IntegerType(), True)])
+ >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3.collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ if isinstance(rdd, DataFrame):
+ raise TypeError("rdd is already a DataFrame")
+
+ if isinstance(schema, StructType):
+ return self.applySchema(rdd, schema)
+ else:
+ if isinstance(schema, (list, tuple)):
+ first = rdd.first()
+ if not isinstance(first, (list, tuple)):
+ raise ValueError("each row in `rdd` should be a list or tuple")
+ row_cls = Row(*schema)
+ rdd = rdd.map(lambda r: row_cls(*r))
+ return self.inferSchema(rdd, samplingRatio)
+
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
"""
if (rdd.__class__ is DataFrame):
@@ -321,7 +370,6 @@ class SQLContext(object):
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlCtx.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -526,7 +574,6 @@ class SQLContext(object):
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
@@ -537,7 +584,6 @@ class SQLContext(object):
def table(self, tableName):
"""Returns the specified table as a L{DataFrame}.
- >>> df = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(df, "table1")
>>> df2 = sqlCtx.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
@@ -685,11 +731,12 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- globs['rdd'] = sc.parallelize(
+ globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['df'] = sqlCtx.createDataFrame(rdd)
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'