-rw-r--r--  docs/_config.yml                |   1
-rw-r--r--  python/docs/pyspark.sql.rst     |   3
-rw-r--r--  python/pyspark/sql/context.py   | 182
-rw-r--r--  python/pyspark/sql/dataframe.py |  56
-rw-r--r--  python/pyspark/sql/functions.py |   1
-rw-r--r--  python/pyspark/sql/tests.py     |  68
-rw-r--r--  python/pyspark/sql/types.py     |   2
7 files changed, 130 insertions, 183 deletions
diff --git a/docs/_config.yml b/docs/_config.yml
index e2db274e1f..0652927a8c 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -10,6 +10,7 @@ kramdown:
include:
- _static
+ - _modules
# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index e03379e521..2e3f69b9a5 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -7,7 +7,6 @@ Module Context
.. automodule:: pyspark.sql
:members:
:undoc-members:
- :show-inheritance:
pyspark.sql.types module
@@ -15,7 +14,6 @@ pyspark.sql.types module
.. automodule:: pyspark.sql.types
:members:
:undoc-members:
- :show-inheritance:
pyspark.sql.functions module
@@ -23,4 +21,3 @@ pyspark.sql.functions module
.. automodule:: pyspark.sql.functions
:members:
:undoc-members:
- :show-inheritance:
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 125933c9d3..5d7aeb664c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -129,6 +129,7 @@ class SQLContext(object):
>>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
+
>>> from pyspark.sql.types import IntegerType
>>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
@@ -197,31 +198,6 @@ class SQLContext(object):
>>> df = sqlCtx.inferSchema(rdd)
>>> df.collect()[0]
Row(field1=1, field2=u'row1')
-
- >>> NestedRow = Row("f1", "f2")
- >>> nestedRdd1 = sc.parallelize([
- ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
- ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> df = sqlCtx.inferSchema(nestedRdd1)
- >>> df.collect()
- [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
-
- >>> nestedRdd2 = sc.parallelize([
- ... NestedRow([[1, 2], [2, 3]], [1, 2]),
- ... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> df = sqlCtx.inferSchema(nestedRdd2)
- >>> df.collect()
- [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
-
- >>> from collections import namedtuple
- >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
- >>> rdd = sc.parallelize(
- ... [CustomRow(field1=1, field2="row1"),
- ... CustomRow(field1=2, field2="row2"),
- ... CustomRow(field1=3, field2="row3")])
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.collect()[0]
- Row(field1=1, field2=u'row1')
"""
if isinstance(rdd, DataFrame):
@@ -252,56 +228,8 @@ class SQLContext(object):
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
>>> df = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerDataFrameAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT * from table1")
- >>> df2.collect()
- [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
-
- >>> from datetime import date, datetime
- >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
- ... date(2010, 1, 1),
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3], None)])
- >>> schema = StructType([
- ... StructField("byte1", ByteType(), False),
- ... StructField("byte2", ByteType(), False),
- ... StructField("short1", ShortType(), False),
- ... StructField("short2", ShortType(), False),
- ... StructField("int1", IntegerType(), False),
- ... StructField("float1", FloatType(), False),
- ... StructField("date1", DateType(), False),
- ... StructField("time1", TimestampType(), False),
- ... StructField("map1",
- ... MapType(StringType(), IntegerType(), False), False),
- ... StructField("struct1",
- ... StructType([StructField("b", ShortType(), False)]), False),
- ... StructField("list1", ArrayType(ByteType(), False), False),
- ... StructField("null1", DoubleType(), True)])
- >>> df = sqlCtx.applySchema(rdd, schema)
- >>> results = df.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
- ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
- >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
- (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
- datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
- >>> df.registerTempTable("table2")
- >>> sqlCtx.sql(
- ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
- ... "float1 + 1.5 as float1 FROM table2").collect()
- [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
-
- >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
- >>> rdd = sc.parallelize([(127, -32768, 1.0,
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3])])
- >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
- >>> schema = _parse_schema_abstract(abstract)
- >>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> df = sqlCtx.applySchema(rdd, typedSchema)
>>> df.collect()
- [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
+ [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
"""
if isinstance(rdd, DataFrame):
@@ -459,46 +387,28 @@ class SQLContext(object):
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
>>> shutil.rmtree(jsonFile)
- >>> ofn = open(jsonFile, 'w')
- >>> for json in jsonStrings:
- ... print>>ofn, json
- >>> ofn.close()
+ >>> with open(jsonFile, 'w') as f:
+ ... f.writelines(jsonStrings)
>>> df1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema)
- >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.printSchema()
+ root
+ |-- field1: long (nullable = true)
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field4: long (nullable = true)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
+ >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> df2.printSchema()
+ root
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field5: array (nullable = true)
+ | | |-- element: integer (containsNull = true)
"""
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
@@ -517,48 +427,23 @@ class SQLContext(object):
determine the schema.
>>> df1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema)
- >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
+
+ >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+ >>> df2.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
-
- >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
- >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))
+ ... ])
+ >>> df3 = sqlCtx.jsonRDD(json, schema)
+ >>> df3.first()
+ Row(field2=u'row1', field3=Row(field5=None))
+
"""
def func(iterator):
@@ -848,7 +733,8 @@ def _test():
globs['jsonStrings'] = jsonStrings
globs['json'] = sc.parallelize(jsonStrings)
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+ pyspark.sql.context, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
exit(-1)
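
Note on the _test() change above: combining doctest.ELLIPSIS with doctest.NORMALIZE_WHITESPACE is what makes the trimmed doctests practical, since multi-line expected output such as the printSchema() trees no longer has to match indentation byte for byte. A minimal, self-contained sketch of that flag combination (the _demo function below is purely illustrative and involves no SparkContext):

    import doctest

    def _demo():
        """
        >>> print('root\\n |-- field2: string (nullable = true)')
        root
         |-- field2: string (nullable = true)
        """

    # Same flag combination as the patched _test(); with NORMALIZE_WHITESPACE
    # the indentation of the expected schema tree is compared leniently.
    doctest.run_docstring_examples(
        _demo, {}, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
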
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6f746d136b..6d42410020 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -96,7 +96,7 @@ class DataFrame(object):
return self._lazy_rdd
def toJSON(self, use_unicode=False):
- """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+ """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row.
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
@@ -108,7 +108,7 @@ class DataFrame(object):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a DataFrame using the L{SQLContext.parquetFile} method.
+ a :class:`DataFrame` using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
@@ -139,7 +139,7 @@ class DataFrame(object):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this DataFrame into the specified table.
+ """Inserts the contents of this :class:`DataFrame` into the specified table.
Optionally overwriting any existing data.
"""
@@ -165,7 +165,7 @@ class DataFrame(object):
return jmode
def saveAsTable(self, tableName, source=None, mode="append", **options):
- """Saves the contents of the DataFrame to a data source as a table.
+ """Saves the contents of the :class:`DataFrame` to a data source as a table.
The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
@@ -174,12 +174,13 @@ class DataFrame(object):
Additionally, mode is used to specify the behavior of the saveAsTable operation when
table already exists in the data source. There are four modes:
- * append: Contents of this DataFrame are expected to be appended to existing table.
- * overwrite: Data in the existing table is expected to be overwritten by the contents of \
- this DataFrame.
+ * append: Contents of this :class:`DataFrame` are expected to be appended \
+ to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by \
+ the contents of this DataFrame.
* error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of the DataFrame and \
- to not change the existing table.
+ * ignore: The save operation is expected to not save the contents of the \
+ :class:`DataFrame` and to not change the existing table.
"""
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
@@ -190,7 +191,7 @@ class DataFrame(object):
self._jdf.saveAsTable(tableName, source, jmode, joptions)
def save(self, path=None, source=None, mode="append", **options):
- """Saves the contents of the DataFrame to a data source.
+ """Saves the contents of the :class:`DataFrame` to a data source.
The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
@@ -199,11 +200,11 @@ class DataFrame(object):
Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:
- * append: Contents of this DataFrame are expected to be appended to existing data.
+ * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
* overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
* error: An exception is expected to be thrown.
- * ignore: The save operation is expected to not save the contents of the DataFrame and \
- to not change the existing data.
+ * ignore: The save operation is expected to not save the contents of \
+ the :class:`DataFrame` and to not change the existing data.
"""
if path is not None:
options["path"] = path
@@ -217,7 +218,7 @@ class DataFrame(object):
@property
def schema(self):
- """Returns the schema of this DataFrame (represented by
+ """Returns the schema of this :class:`DataFrame` (represented by
a L{StructType}).
>>> df.schema
@@ -275,12 +276,12 @@ class DataFrame(object):
"""
Print the first 20 rows.
+ >>> df
+ DataFrame[age: int, name: string]
>>> df.show()
age name
2 Alice
5 Bob
- >>> df
- DataFrame[age: int, name: string]
"""
print self._jdf.showString().encode('utf8', 'ignore')
@@ -481,8 +482,8 @@ class DataFrame(object):
def join(self, other, joinExprs=None, joinType=None):
"""
- Join with another DataFrame, using the given join expression.
- The following performs a full outer join between `df1` and `df2`::
+ Join with another :class:`DataFrame`, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`.
:param other: Right side of the join
:param joinExprs: Join expression
@@ -582,8 +583,6 @@ class DataFrame(object):
def select(self, *cols):
""" Selecting a set of expressions.
- >>> df.select().collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('*').collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
@@ -591,8 +590,6 @@ class DataFrame(object):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
- if not cols:
- cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
@@ -612,7 +609,7 @@ class DataFrame(object):
def filter(self, condition):
""" Filtering rows using the given condition, which could be
- Column expression or string of SQL expression.
+ :class:`Column` expression or string of SQL expression.
where() is an alias for filter().
@@ -666,7 +663,7 @@ class DataFrame(object):
return self.groupBy().agg(*exprs)
def unionAll(self, other):
- """ Return a new DataFrame containing union of rows in this
+ """ Return a new :class:`DataFrame` containing union of rows in this
frame and another frame.
This is equivalent to `UNION ALL` in SQL.
@@ -919,9 +916,10 @@ class Column(object):
"""
A column in a DataFrame.
- `Column` instances can be created by::
+ :class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
+
df.colName
df["colName"]
@@ -975,7 +973,7 @@ class Column(object):
def substr(self, startPos, length):
"""
- Return a Column which is a substring of the column
+ Return a :class:`Column` which is a substring of the column
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
@@ -996,8 +994,10 @@ class Column(object):
__getslice__ = substr
# order
- asc = _unary_op("asc")
- desc = _unary_op("desc")
+ asc = _unary_op("asc", "Returns a sort expression based on the"
+ " ascending order of the given column name.")
+ desc = _unary_op("desc", "Returns a sort expression based on the"
+ " descending order of the given column name.")
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
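
For context on the newly documented asc/desc expressions: they are typically consumed by DataFrame.sort()/orderBy(). A doctest-style sketch against the same two-row df fixture (age=2/Alice, age=5/Bob) used throughout this module, assuming the fixture and sort API behave as in the surrounding doctests:

    >>> df.sort(df.age.desc()).collect()
    [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
    >>> df.orderBy(df.name.asc()).first()
    Row(age=2, name=u'Alice')

The Column returned by asc()/desc() only describes a sort order; it is not evaluated on its own.
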
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8aa4476520..5873f09ae3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -72,6 +72,7 @@ for _name, _doc in _functions.items():
globals()[_name] = _create_function(_name, _doc)
del _name, _doc
__all__ += _functions.keys()
+__all__.sort()
def countDistinct(col, *cols):
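
The added __all__.sort() matters because the names next to it are generated at import time; sorting presumably keeps __all__ in a deterministic, alphabetical order for wildcard imports and the generated API docs. A stripped-down sketch of that pattern (the _functions entries and the wrapper body here are illustrative placeholders, not the real JVM-backed implementations):

    _functions = {
        'lit': 'Creates a Column of literal value.',
        'col': 'Returns a Column based on the given column name.',
    }

    __all__ = ['countDistinct']

    def _create_function(name, doc):
        def _(col):
            # placeholder body: the real wrappers delegate to the JVM side
            raise NotImplementedError(name)
        _.__name__ = name
        _.__doc__ = doc
        return _

    for _name, _doc in _functions.items():
        globals()[_name] = _create_function(_name, _doc)
    del _name, _doc

    __all__ += _functions.keys()
    __all__.sort()   # deterministic ordering for imports and documentation
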
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 39071e7e35..83899ad4b1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -36,9 +36,9 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.sql import SQLContext, HiveContext, Column
-from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType, StringType, _infer_type
+from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql.types import *
+from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -204,6 +204,68 @@ class SQLTests(ReusedPySparkTestCase):
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
+ def test_infer_nested_schema(self):
+ NestedRow = Row("f1", "f2")
+ nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+ NestedRow([2, 3], {"row2": 2.0})])
+ df = self.sqlCtx.inferSchema(nestedRdd1)
+ self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+ nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+ NestedRow([[2, 3], [3, 4]], [2, 3])])
+ df = self.sqlCtx.inferSchema(nestedRdd2)
+ self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+ from collections import namedtuple
+ CustomRow = namedtuple('CustomRow', 'field1 field2')
+ rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+ CustomRow(field1=2, field2="row2"),
+ CustomRow(field1=3, field2="row3")])
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+
+ def test_apply_schema(self):
+ from datetime import date, datetime
+ rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3], None)])
+ schema = StructType([
+ StructField("byte1", ByteType(), False),
+ StructField("byte2", ByteType(), False),
+ StructField("short1", ShortType(), False),
+ StructField("short2", ShortType(), False),
+ StructField("int1", IntegerType(), False),
+ StructField("float1", FloatType(), False),
+ StructField("date1", DateType(), False),
+ StructField("time1", TimestampType(), False),
+ StructField("map1", MapType(StringType(), IntegerType(), False), False),
+ StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
+ StructField("list1", ArrayType(ByteType(), False), False),
+ StructField("null1", DoubleType(), True)])
+ df = self.sqlCtx.applySchema(rdd, schema)
+ results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
+ x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+ r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+ datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ self.assertEqual(r, results.first())
+
+ df.registerTempTable("table2")
+ r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ "float1 + 1.5 as float1 FROM table2").first()
+
+ self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+
+ from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+ rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3])])
+ abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
+ schema = _parse_schema_abstract(abstract)
+ typedSchema = _infer_schema_type(rdd.first(), schema)
+ df = self.sqlCtx.applySchema(rdd, typedSchema)
+ r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
+ self.assertEqual(r, tuple(df.first()))
+
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
df = self.sc.parallelize(d).toDF()
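
The two new tests above absorb the doctest material removed from context.py. A hypothetical one-off runner for just these cases, assuming a local Spark checkout with the pyspark modules and py4j on PYTHONPATH (the usual route is the bundled python/run-tests script):

    import unittest
    from pyspark.sql.tests import SQLTests

    # Assumes SPARK_HOME is set so ReusedPySparkTestCase can start a local
    # SparkContext in setUpClass.
    suite = unittest.TestSuite([
        SQLTests('test_infer_nested_schema'),
        SQLTests('test_apply_schema'),
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)
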
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index b6e41cf0b2..0f5dc2be6d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -28,7 +28,7 @@ from operator import itemgetter
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
- "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ]
+ "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
class DataType(object):