author    | Yin Huai <huai@cse.ohio-state.edu>        | 2014-08-05 18:56:10 -0700
committer | Michael Armbrust <michael@databricks.com> | 2014-08-05 18:56:10 -0700
commit    | 69ec678d3aaeb6ece85e5e82353bf083bfc83667 (patch)
tree      | 5bb73c4135b8971c8bbaf8811b7efaf6e68449d2
parent    | d0ae3f3912104a8227cd964c42e229a297a48ffa (diff)
[SPARK-2854][SQL] Finalize _acceptable_types in pyspark.sql
This PR finalizes which Python value types are accepted, per Spark SQL data type, for rows in a Python RDD passed to `applySchema`.
JIRA: https://issues.apache.org/jira/browse/SPARK-2854
Author: Yin Huai <huai@cse.ohio-state.edu>
Closes #1793 from yhuai/SPARK-2854 and squashes the following commits:
32f0708 [Yin Huai] LongType only accepts long values.
c2b23dd [Yin Huai] Do data type conversions based on the specified Spark SQL data type.
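To illustrate the effect of the tightened mapping, here is a minimal standalone sketch (Python 2, since `long` is a 2.x type) of how an `_acceptable_types`-style table can gate input values by declared Spark SQL type. The simplified type stand-ins and the `check_value` helper are illustrative only; they are not pyspark's actual implementation, which performs its own verification inside `applySchema`.

```python
# Minimal sketch of gating values by declared type (Python 2).
# The stand-in classes and check_value helper are hypothetical,
# not pyspark's API; only the accepted-type tuples mirror the patch.
import datetime

class LongType(object): pass
class TimestampType(object): pass
class StringType(object): pass

_acceptable_types = {
    LongType: (long,),                    # was (int, long) before this patch
    TimestampType: (datetime.datetime,),  # was (datetime, time, date) before
    StringType: (str, unicode),
}

def check_value(data_type, value):
    """Raise TypeError if `value` is not an accepted Python type for `data_type`."""
    accepted = _acceptable_types[type(data_type)]
    if type(value) not in accepted:
        raise TypeError("%s can not accept object in type %s"
                        % (type(data_type).__name__, type(value)))

check_value(LongType(), 9223372036854775807L)                         # ok: Python long
check_value(TimestampType(), datetime.datetime(2010, 1, 1, 1, 1, 1))  # ok
try:
    check_value(LongType(), 1)   # plain int is no longer accepted for LongType
except TypeError as e:
    print(e)
```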
-rw-r--r-- | python/pyspark/sql.py                                          | 29
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala |  3
2 files changed, 23 insertions, 9 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 1a829c6faf..f1093701dd 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -672,12 +672,12 @@ _acceptable_types = {
     ByteType: (int, long),
     ShortType: (int, long),
     IntegerType: (int, long),
-    LongType: (int, long),
+    LongType: (long,),
     FloatType: (float,),
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str, unicode),
-    TimestampType: (datetime.datetime, datetime.time, datetime.date),
+    TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
     StructType: (tuple, list),
@@ -1042,12 +1042,15 @@ class SQLContext:
         [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
 
         >>> from datetime import datetime
-        >>> rdd = sc.parallelize([(127, -32768, 1.0,
+        >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
         ...     datetime(2010, 1, 1, 1, 1, 1),
         ...     {"a": 1}, (2,), [1, 2, 3], None)])
         >>> schema = StructType([
-        ...     StructField("byte", ByteType(), False),
-        ...     StructField("short", ShortType(), False),
+        ...     StructField("byte1", ByteType(), False),
+        ...     StructField("byte2", ByteType(), False),
+        ...     StructField("short1", ShortType(), False),
+        ...     StructField("short2", ShortType(), False),
+        ...     StructField("int", IntegerType(), False),
         ...     StructField("float", FloatType(), False),
         ...     StructField("time", TimestampType(), False),
         ...     StructField("map",
@@ -1056,11 +1059,19 @@ class SQLContext:
         ...         StructType([StructField("b", ShortType(), False)]), False),
         ...     StructField("list", ArrayType(ByteType(), False), False),
         ...     StructField("null", DoubleType(), True)])
-        >>> srdd = sqlCtx.applySchema(rdd, schema).map(
-        ...     lambda x: (x.byte, x.short, x.float, x.time,
+        >>> srdd = sqlCtx.applySchema(rdd, schema)
+        >>> results = srdd.map(
+        ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
         ...         x.map["a"], x.struct.b, x.list, x.null))
-        >>> srdd.collect()[0]
-        (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+        >>> results.collect()[0]
+        (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+
+        >>> srdd.registerTempTable("table2")
+        >>> sqlCtx.sql(
+        ...   "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+        ...     "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
+        ...     "float + 1.1 as float FROM table2").collect()
+        [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)]
 
         >>> rdd = sc.parallelize([(127, -32768, 1.0,
         ...     datetime(2010, 1, 1, 1, 1, 1),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ecd5fbaa0b..71d338d21d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
         new java.sql.Timestamp(c.getTime().getTime())
 
       case (c: Int, ByteType) => c.toByte
+      case (c: Long, ByteType) => c.toByte
       case (c: Int, ShortType) => c.toShort
+      case (c: Long, ShortType) => c.toShort
+      case (c: Long, IntegerType) => c.toInt
       case (c: Double, FloatType) => c.toFloat
       case (c, StringType) if !c.isInstanceOf[String] => c.toString
 
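For context, the snippet below condenses the doctest added by this patch into a standalone script. It assumes a pyspark 1.1-era environment (Python 2) where `sc` and `sqlCtx` already exist, e.g. in the pyspark shell; it is a sketch of the intended end-to-end behavior, not part of the change itself.

```python
# Condensed from the doctest added by this patch; assumes `sc` and `sqlCtx`
# are provided by a pyspark (1.1-era, Python 2) shell session.
from datetime import datetime
from pyspark.sql import (StructType, StructField, ByteType, IntegerType,
                         TimestampType)

rdd = sc.parallelize([(127, 2147483647L, datetime(2010, 1, 1, 1, 1, 1))])
schema = StructType([
    StructField("byte", ByteType(), False),
    StructField("int", IntegerType(), False),
    StructField("time", TimestampType(), False)])

# On the Scala side, Python long values arriving for Byte/Short/Integer
# columns are now narrowed with toByte/toShort/toInt based on the declared
# schema type, so the long 2147483647L above is valid for IntegerType.
srdd = sqlCtx.applySchema(rdd, schema)
srdd.registerTempTable("t")
print(sqlCtx.sql("SELECT byte - 1 AS byte, int - 1 AS int FROM t").collect())
# Per the commit's doctest, this yields the values 126 and 2147483646.
```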