author     Davies Liu <davies@databricks.com>          2015-02-18 14:17:04 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-02-18 14:17:04 -0800
commit     aa8f10e82a743d59ce87348af19c0177eb618a66 (patch)
tree       87fc8bfc978015fcf3d7ff9ff2aa3717b0885f28 /python/pyspark
parent     f0e3b71077a6c28aba29a7a75e901a9e0911b9f0 (diff)
[SPARK-5722] [SQL] [PySpark] infer int as LongType
The `int` is 64-bit on 64-bit machines (very common now), so we should infer it as LongType in Spark SQL. Also, LongType in SQL will come back as `int`.

Author: Davies Liu <davies@databricks.com>

Closes #4666 from davies/long and squashes the following commits:

6bc6cc4 [Davies Liu] infer int as LongType
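As a quick illustration of the new behavior (a minimal sketch, not part of this commit, assuming a local Spark 1.3-style PySpark session; the app name and column names are made up): any plain Python int is now inferred as LongType, and LongType values come back to Python as ordinary ints.

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row
    from pyspark.sql.types import LongType

    sc = SparkContext('local[2]', 'infer-long-example')  # illustrative app name
    sqlCtx = SQLContext(sc)  # creating the SQLContext also enables rdd.toDF()

    # Both columns are inferred as LongType, even though `age` fits in 32 bits.
    df = sc.parallelize([Row(age=2, balance=100000000000000)]).toDF()
    assert df.schema.fields[0].dataType == LongType()
    assert df.schema.fields[1].dataType == LongType()

    # On collect(), LongType values are returned as plain Python ints.
    print(df.first())  # Row(age=2, balance=100000000000000)

    sc.stop()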
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/dataframe.py  14
-rw-r--r--  python/pyspark/sql/tests.py      22
-rw-r--r--  python/pyspark/sql/types.py       8
3 files changed, 33 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 52bd75bf8a..c68c97e926 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -803,7 +803,7 @@ class GroupedData(object):
>>> df.groupBy().mean('age').collect()
[Row(AVG(age#0)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
- [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
"""
@df_varargs_api
@@ -814,7 +814,7 @@ class GroupedData(object):
>>> df.groupBy().avg('age').collect()
[Row(AVG(age#0)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
- [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
"""
@df_varargs_api
@@ -825,7 +825,7 @@ class GroupedData(object):
>>> df.groupBy().max('age').collect()
[Row(MAX(age#0)=5)]
>>> df3.groupBy().max('age', 'height').collect()
- [Row(MAX(age#4)=5, MAX(height#5)=85)]
+ [Row(MAX(age#4L)=5, MAX(height#5L)=85)]
"""
@df_varargs_api
@@ -836,7 +836,7 @@ class GroupedData(object):
>>> df.groupBy().min('age').collect()
[Row(MIN(age#0)=2)]
>>> df3.groupBy().min('age', 'height').collect()
- [Row(MIN(age#4)=2, MIN(height#5)=80)]
+ [Row(MIN(age#4L)=2, MIN(height#5L)=80)]
"""
@df_varargs_api
@@ -847,7 +847,7 @@ class GroupedData(object):
>>> df.groupBy().sum('age').collect()
[Row(SUM(age#0)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
- [Row(SUM(age#4)=7, SUM(height#5)=165)]
+ [Row(SUM(age#4L)=7, SUM(height#5L)=165)]
"""
@@ -1051,7 +1051,9 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
- globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
+ .toDF(StructType([StructField('age', IntegerType()),
+ StructField('name', StringType())]))
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
Row(name='Bob', age=5, height=85)]).toDF()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 52f7e65d9c..8e1bb36598 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -38,7 +38,7 @@ else:
from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType, StringType
+ UserDefinedType, DoubleType, LongType, StringType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -324,6 +324,26 @@ class SQLTests(ReusedPySparkTestCase):
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
+ def test_infer_long_type(self):
+ longrow = [Row(f1='a', f2=100000000000000)]
+ df = self.sc.parallelize(longrow).toDF()
+ self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+ # saving this as Parquet caused issues as well
+ output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+ df.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ self.assertEquals('a', df1.first().f1)
+ self.assertEquals(100000000000000, df1.first().f2)
+
+ self.assertEqual(_infer_type(1), LongType())
+ self.assertEqual(_infer_type(2**10), LongType())
+ self.assertEqual(_infer_type(2**20), LongType())
+ self.assertEqual(_infer_type(2**31 - 1), LongType())
+ self.assertEqual(_infer_type(2**31), LongType())
+ self.assertEqual(_infer_type(2**61), LongType())
+ self.assertEqual(_infer_type(2**71), LongType())
+
class HiveContextSQLTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 40bd7e54a9..9409c6f9f6 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -583,7 +583,7 @@ def _parse_datatype_json_value(json_value):
_type_mappings = {
type(None): NullType,
bool: BooleanType,
- int: IntegerType,
+ int: LongType,
long: LongType,
float: DoubleType,
str: StringType,
@@ -933,11 +933,11 @@ def _infer_schema_type(obj, dataType):
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...DateType...
+ StructType...LongType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
- StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+ StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
if dataType is None:
return _infer_type(obj)
@@ -992,7 +992,7 @@ def _verify_type(obj, dataType):
>>> _verify_type(None, StructType([]))
>>> _verify_type("", StringType())
- >>> _verify_type(0, IntegerType())
+ >>> _verify_type(0, LongType())
>>> _verify_type(range(3), ArrayType(ShortType()))
>>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):