author     Davies Liu <davies@databricks.com>         2015-02-24 20:51:55 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-24 20:51:55 -0800
commit     d641fbb39c90b1d734cc55396ca43d7e98788975 (patch)
tree       d9741f7c08d6ba288b224e8c2af60fc1bdb445f3
parent     769e092bdc51582372093f76dbaece27149cc4ea (diff)
[SPARK-5994] [SQL] Python DataFrame documentation fixes
select with no columns should NOT be the same as select("*"); make sure selectExpr behaves the same
join param documentation
link to source doesn't work in jekyll-generated files
cross-reference columns (i.e. enable linking)
show(): move the df example before df.show()
move tests in SQLContext out of the docstring, otherwise the doc is too long
Column.desc and .asc don't have any documentation
in documentation, sort functions.*

Author: Davies Liu <davies@databricks.com>

Closes #4756 from davies/df_docs and squashes the following commits:

f30502c [Davies Liu] fix doc
32f0d46 [Davies Liu] fix DataFrame docs
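
For context, a minimal sketch (not part of this patch) of the select() behaviour the first point refers to, using the 1.3-era PySpark API that appears in the diff below; the SparkContext/SQLContext setup mirrors the module's doctest fixtures:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "select-example")
    sqlCtx = SQLContext(sc)
    df = sqlCtx.inferSchema(sc.parallelize([Row(age=2, name=u'Alice'),
                                            Row(age=5, name=u'Bob')]))

    df.select('*').collect()            # all columns, requested explicitly
    df.select('name', 'age').collect()  # a named subset of columns
    # After this change, df.select() with no arguments is no longer silently
    # rewritten to df.select('*'); callers must name the columns they want.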
-rw-r--r--  docs/_config.yml                   1
-rw-r--r--  python/docs/pyspark.sql.rst        3
-rw-r--r--  python/pyspark/sql/context.py    182
-rw-r--r--  python/pyspark/sql/dataframe.py   56
-rw-r--r--  python/pyspark/sql/functions.py    1
-rw-r--r--  python/pyspark/sql/tests.py       68
-rw-r--r--  python/pyspark/sql/types.py        2
7 files changed, 130 insertions, 183 deletions
diff --git a/docs/_config.yml b/docs/_config.yml
index e2db274e1f..0652927a8c 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -10,6 +10,7 @@ kramdown:
include:
- _static
+ - _modules
# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index e03379e521..2e3f69b9a5 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -7,7 +7,6 @@ Module Context
.. automodule:: pyspark.sql
:members:
:undoc-members:
- :show-inheritance:
pyspark.sql.types module
@@ -15,7 +14,6 @@ pyspark.sql.types module
.. automodule:: pyspark.sql.types
:members:
:undoc-members:
- :show-inheritance:
pyspark.sql.functions module
@@ -23,4 +21,3 @@ pyspark.sql.functions module
.. automodule:: pyspark.sql.functions
:members:
:undoc-members:
- :show-inheritance:
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 125933c9d3..5d7aeb664c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -129,6 +129,7 @@ class SQLContext(object):
>>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
+
>>> from pyspark.sql.types import IntegerType
>>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
@@ -197,31 +198,6 @@ class SQLContext(object):
>>> df = sqlCtx.inferSchema(rdd)
>>> df.collect()[0]
Row(field1=1, field2=u'row1')
-
- >>> NestedRow = Row("f1", "f2")
- >>> nestedRdd1 = sc.parallelize([
- ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
- ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> df = sqlCtx.inferSchema(nestedRdd1)
- >>> df.collect()
- [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
-
- >>> nestedRdd2 = sc.parallelize([
- ... NestedRow([[1, 2], [2, 3]], [1, 2]),
- ... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> df = sqlCtx.inferSchema(nestedRdd2)
- >>> df.collect()
- [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
-
- >>> from collections import namedtuple
- >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
- >>> rdd = sc.parallelize(
- ... [CustomRow(field1=1, field2="row1"),
- ... CustomRow(field1=2, field2="row2"),
- ... CustomRow(field1=3, field2="row3")])
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.collect()[0]
- Row(field1=1, field2=u'row1')
"""
if isinstance(rdd, DataFrame):
@@ -252,56 +228,8 @@ class SQLContext(object):
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
>>> df = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT * from table1")
- >>> df2.collect()
- [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
-
- >>> from datetime import date, datetime
- >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
- ... date(2010, 1, 1),
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3], None)])
- >>> schema = StructType([
- ... StructField("byte1", ByteType(), False),
- ... StructField("byte2", ByteType(), False),
- ... StructField("short1", ShortType(), False),
- ... StructField("short2", ShortType(), False),
- ... StructField("int1", IntegerType(), False),
- ... StructField("float1", FloatType(), False),
- ... StructField("date1", DateType(), False),
- ... StructField("time1", TimestampType(), False),
- ... StructField("map1",
- ... MapType(StringType(), IntegerType(), False), False),
- ... StructField("struct1",
- ... StructType([StructField("b", ShortType(), False)]), False),
- ... StructField("list1", ArrayType(ByteType(), False), False),
- ... StructField("null1", DoubleType(), True)])
- >>> df = sqlCtx.applySchema(rdd, schema)
- >>> results = df.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
- ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
- >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
- (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
- datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
- >>> df.registerTempTable("table2")
- >>> sqlCtx.sql(
- ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
- ... "float1 + 1.5 as float1 FROM table2").collect()
- [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
-
- >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
- >>> rdd = sc.parallelize([(127, -32768, 1.0,
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3])])
- >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
- >>> schema = _parse_schema_abstract(abstract)
- >>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> df = sqlCtx.applySchema(rdd, typedSchema)
>>> df.collect()
- [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
+ [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
"""
if isinstance(rdd, DataFrame):
@@ -459,46 +387,28 @@ class SQLContext(object):
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
>>> shutil.rmtree(jsonFile)
- >>> ofn = open(jsonFile, 'w')
- >>> for json in jsonStrings:
- ... print>>ofn, json
- >>> ofn.close()
+ >>> with open(jsonFile, 'w') as f:
+ ... f.writelines(jsonStrings)
>>> df1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema)
- >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.printSchema()
+ root
+ |-- field1: long (nullable = true)
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field4: long (nullable = true)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
+ >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> df2.printSchema()
+ root
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field5: array (nullable = true)
+ | | |-- element: integer (containsNull = true)
"""
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
@@ -517,48 +427,23 @@ class SQLContext(object):
determine the schema.
>>> df1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema)
- >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
+
+ >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+ >>> df2.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
-
- >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
- >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))
+ ... ])
+ >>> df3 = sqlCtx.jsonRDD(json, schema)
+ >>> df3.first()
+ Row(field2=u'row1', field3=Row(field5=None))
+
"""
def func(iterator):
@@ -848,7 +733,8 @@ def _test():
globs['jsonStrings'] = jsonStrings
globs['json'] = sc.parallelize(jsonStrings)
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+ pyspark.sql.context, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
exit(-1)
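
A standalone illustration (plain Python, not PySpark code) of why NORMALIZE_WHITESPACE is added alongside ELLIPSIS in _test() above: it lets long expected output, such as the wrapped Row(...) lines and printSchema() trees in the new doctests, match even when spacing or line wrapping differs:

    import doctest

    def long_row():
        """
        With NORMALIZE_WHITESPACE, the long expected Row can be wrapped:

        >>> print("Row(field1=1, field2=u'row1', field3=None)")
        Row(field1=1, field2=u'row1',
            field3=None)
        """

    # Passes with the flag; without it the line break makes the comparison fail.
    doctest.run_docstring_examples(
        long_row, {},
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)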
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6f746d136b..6d42410020 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -96,7 +96,7 @@ class DataFrame(object):
return self._lazy_rdd
def toJSON(self, use_unicode=False):
- """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+ """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row.
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
@@ -108,7 +108,7 @@ class DataFrame(object):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a DataFrame using the L{SQLContext.parquetFile} method.
+ a :class:`DataFrame` using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
@@ -139,7 +139,7 @@ class DataFrame(object):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this DataFrame into the specified table.
+ """Inserts the contents of this :class:`DataFrame` into the specified table.
Optionally overwriting any existing data.
"""
@@ -165,7 +165,7 @@ class DataFrame(object):
return jmode
def saveAsTable(self, tableName, source=None, mode="append", **options):
- """Saves the contents of the DataFrame to a data source as a table.
+ """Saves the contents of the :class:`DataFrame` to a data source as a table.
The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
@@ -174,12 +174,13 @@ class DataFrame(object):
Additionally, mode is used to specify the behavior of the saveAsTable operation when
table already exists in the data source. There are four modes:
- * append: Contents of this DataFrame are expected to be appended to existing table.
- * overwrite: Data in the existing table is expected to be overwritten by the contents of \
- this DataFrame.
+ * append: Contents of this :class:`DataFrame` are expected to be appended \
+ to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by \
+ the contents of this DataFrame.
* error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of the DataFrame and \
- to not change the existing table.
+ * ignore: The save operation is expected to not save the contents of the \
+ :class:`DataFrame` and to not change the existing table.
"""
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
@@ -190,7 +191,7 @@ class DataFrame(object):
self._jdf.saveAsTable(tableName, source, jmode, joptions)
def save(self, path=None, source=None, mode="append", **options):
- """Saves the contents of the DataFrame to a data source.
+ """Saves the contents of the :class:`DataFrame` to a data source.
The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
@@ -199,11 +200,11 @@ class DataFrame(object):
Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:
- * append: Contents of this DataFrame are expected to be appended to existing data.
+ * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
* overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
* error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of the DataFrame and \
- to not change the existing data.
+ * ignore: The save operation is expected to not save the contents of \
+ the :class:`DataFrame` and to not change the existing data.
"""
if path is not None:
options["path"] = path
@@ -217,7 +218,7 @@ class DataFrame(object):
@property
def schema(self):
- """Returns the schema of this DataFrame (represented by
+ """Returns the schema of this :class:`DataFrame` (represented by
a L{StructType}).
>>> df.schema
@@ -275,12 +276,12 @@ class DataFrame(object):
"""
Print the first 20 rows.
+ >>> df
+ DataFrame[age: int, name: string]
>>> df.show()
age name
2 Alice
5 Bob
- >>> df
- DataFrame[age: int, name: string]
"""
print self._jdf.showString().encode('utf8', 'ignore')
@@ -481,8 +482,8 @@ class DataFrame(object):
def join(self, other, joinExprs=None, joinType=None):
"""
- Join with another DataFrame, using the given join expression.
- The following performs a full outer join between `df1` and `df2`::
+ Join with another :class:`DataFrame`, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`.
:param other: Right side of the join
:param joinExprs: Join expression
@@ -582,8 +583,6 @@ class DataFrame(object):
def select(self, *cols):
""" Selecting a set of expressions.
- >>> df.select().collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('*').collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
@@ -591,8 +590,6 @@ class DataFrame(object):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
- if not cols:
- cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
@@ -612,7 +609,7 @@ class DataFrame(object):
def filter(self, condition):
""" Filtering rows using the given condition, which could be
- Column expression or string of SQL expression.
+ :class:`Column` expression or string of SQL expression.
where() is an alias for filter().
@@ -666,7 +663,7 @@ class DataFrame(object):
return self.groupBy().agg(*exprs)
def unionAll(self, other):
- """ Return a new DataFrame containing union of rows in this
+ """ Return a new :class:`DataFrame` containing union of rows in this
frame and another frame.
This is equivalent to `UNION ALL` in SQL.
@@ -919,9 +916,10 @@ class Column(object):
"""
A column in a DataFrame.
- `Column` instances can be created by::
+ :class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
+
df.colName
df["colName"]
@@ -975,7 +973,7 @@ class Column(object):
def substr(self, startPos, length):
"""
- Return a Column which is a substring of the column
+ Return a :class:`Column` which is a substring of the column
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
@@ -996,8 +994,10 @@ class Column(object):
__getslice__ = substr
# order
- asc = _unary_op("asc")
- desc = _unary_op("desc")
+ asc = _unary_op("asc", "Returns a sort expression based on the"
+ " ascending order of the given column name.")
+ desc = _unary_op("desc", "Returns a sort expression based on the"
+ " descending order of the given column name.")
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
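
The hunk above adds documentation for Column.asc and Column.desc; a small usage sketch, reusing the two-row df fixture from the doctests above (columns age, name) and assuming the 1.3-era DataFrame.sort API:

    df.sort(df.age.desc()).collect()
    # [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]

    df.sort(df.age.asc(), df.name.desc()).collect()
    # [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]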
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8aa4476520..5873f09ae3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -72,6 +72,7 @@ for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
del _name, _doc
__all__ += _functions.keys()
+__all__.sort()
def countDistinct(col, *cols):
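
The one-line __all__.sort() above keeps the exported function list in a stable, alphabetical order, which is what the "sort functions.*" note in the commit message refers to. A minimal self-contained sketch of the effect (hypothetical names, not the real functions module):

    # __all__ is extended from a dict of generated functions, so without
    # sorting its order depends on the dict (unordered in Python 2).
    _functions = {"min": "...", "max": "...", "abs": "..."}
    __all__ = ["countDistinct"]
    __all__ += _functions.keys()
    __all__.sort()
    print(__all__)   # ['abs', 'countDistinct', 'max', 'min']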
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 39071e7e35..83899ad4b1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -36,9 +36,9 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.sql import SQLContext, HiveContext, Column
-from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType, StringType, _infer_type
+from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql.types import *
+from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -204,6 +204,68 @@ class SQLTests(ReusedPySparkTestCase):
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
+ def test_infer_nested_schema(self):
+ NestedRow = Row("f1", "f2")
+ nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+ NestedRow([2, 3], {"row2": 2.0})])
+ df = self.sqlCtx.inferSchema(nestedRdd1)
+ self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+ nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+ NestedRow([[2, 3], [3, 4]], [2, 3])])
+ df = self.sqlCtx.inferSchema(nestedRdd2)
+ self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+ from collections import namedtuple
+ CustomRow = namedtuple('CustomRow', 'field1 field2')
+ rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+ CustomRow(field1=2, field2="row2"),
+ CustomRow(field1=3, field2="row3")])
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+
+ def test_apply_schema(self):
+ from datetime import date, datetime
+ rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3], None)])
+ schema = StructType([
+ StructField("byte1", ByteType(), False),
+ StructField("byte2", ByteType(), False),
+ StructField("short1", ShortType(), False),
+ StructField("short2", ShortType(), False),
+ StructField("int1", IntegerType(), False),
+ StructField("float1", FloatType(), False),
+ StructField("date1", DateType(), False),
+ StructField("time1", TimestampType(), False),
+ StructField("map1", MapType(StringType(), IntegerType(), False), False),
+ StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
+ StructField("list1", ArrayType(ByteType(), False), False),
+ StructField("null1", DoubleType(), True)])
+ df = self.sqlCtx.applySchema(rdd, schema)
+ results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
+ x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+ r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+ datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ self.assertEqual(r, results.first())
+
+ df.registerTempTable("table2")
+ r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ "float1 + 1.5 as float1 FROM table2").first()
+
+ self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+
+ from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+ rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3])])
+ abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
+ schema = _parse_schema_abstract(abstract)
+ typedSchema = _infer_schema_type(rdd.first(), schema)
+ df = self.sqlCtx.applySchema(rdd, typedSchema)
+ r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
+ self.assertEqual(r, tuple(df.first()))
+
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
df = self.sc.parallelize(d).toDF()
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b6e41cf0b2..0f5dc2be6d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -28,7 +28,7 @@ from operator import itemgetter
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
- "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ]
+ "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
class DataType(object):