author     Cheng Lian <lian.cs.zju@gmail.com>          2014-08-13 16:27:50 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-08-13 16:27:50 -0700
commit     bdc7a1a4749301f8d18617c130c7766684aa8789 (patch)
tree       eadad0f16afbf77e9f605e2df55e941825f60e09 /sql
parent     7ecb867c4cd6916b6cb12f2ece1a4c88591ad5b5 (diff)
[SPARK-3004][SQL] Added null checking when retrieving row set
JIRA issue: [SPARK-3004](https://issues.apache.org/jira/browse/SPARK-3004)

HiveThriftServer2 throws an exception when the result set contains `NULL`. We should check `isNullAt` in `SparkSQLOperationManager.getNextRowSet`. Note that simply using `row.addColumnValue(null)` doesn't work, since Hive sets the column type of a null `ColumnValue` to String by default.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #1920 from liancheng/spark-3004 and squashes the following commits:

1b1db1c [Cheng Lian] Adding NULL column values in the Hive way
2217722 [Cheng Lian] Fixed SPARK-3004: added null checking when retrieving row set
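To make the "Hive way" concrete: Hive's `ColumnValue` factories are overloaded per column type, so a NULL added through the wrong factory (or as a bare `null`) reaches the Thrift client typed as String. A minimal sketch of the idea follows; it is illustrative only and not part of the patch. The helper name `addTypedNull` is hypothetical, while the `Row`/`ColumnValue` calls are the same ones the patch itself uses.

    // Assumed imports, matching what the patched file appears to use:
    //   org.apache.hive.service.cli.{Row, ColumnValue}
    //   org.apache.spark.sql.catalyst.types._
    //
    // Hypothetical helper: dispatch a NULL through the factory matching the
    // column's Catalyst type, so the client sees a correctly typed NULL.
    def addTypedNull(to: Row, columnType: DataType): Unit = columnType match {
      case StringType  => to.addString(null)
      case IntegerType => to.addColumnValue(ColumnValue.intValue(null))
      case DoubleType  => to.addColumnValue(ColumnValue.doubleValue(null))
      case _           => to.addColumnValue(ColumnValue.stringValue(null: String))
    }

The patch below applies this per-type dispatch in `addNullColumnValue`, mirroring the dispatch that `addNonNullColumnValue` performs for non-null values.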
Diffstat (limited to 'sql')
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala | 93
-rw-r--r--  sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt | 10
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala | 26
3 files changed, 96 insertions, 33 deletions
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index f192f490ac..9338e8121b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
           var curCol = 0
 
           while (curCol < sparkRow.length) {
-            dataTypes(curCol) match {
-              case StringType =>
-                row.addString(sparkRow(curCol).asInstanceOf[String])
-              case IntegerType =>
-                row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
-              case BooleanType =>
-                row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
-              case DoubleType =>
-                row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
-              case FloatType =>
-                row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
-              case DecimalType =>
-                val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
-                row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
-              case LongType =>
-                row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
-              case ByteType =>
-                row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
-              case ShortType =>
-                row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
-              case TimestampType =>
-                row.addColumnValue(
-                  ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
-              case BinaryType | _: ArrayType | _: StructType | _: MapType =>
-                val hiveString = result
-                  .queryExecution
-                  .asInstanceOf[HiveContext#QueryExecution]
-                  .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
-                row.addColumnValue(ColumnValue.stringValue(hiveString))
+            if (sparkRow.isNullAt(curCol)) {
+              addNullColumnValue(sparkRow, row, curCol)
+            } else {
+              addNonNullColumnValue(sparkRow, row, curCol)
             }
             curCol += 1
           }
@@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
         }
       }
 
+      def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(from(ordinal).asInstanceOf[String])
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal)))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal)))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
+          case DecimalType =>
+            val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
+            to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal)))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+          case TimestampType =>
+            to.addColumnValue(
+              ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            val hiveString = result
+              .queryExecution
+              .asInstanceOf[HiveContext#QueryExecution]
+              .toHiveString((from.get(ordinal), dataTypes(ordinal)))
+            to.addColumnValue(ColumnValue.stringValue(hiveString))
+        }
+      }
+
+      def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(null)
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(null))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(null))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(null))
+          case DecimalType =>
+            to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(null))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(null))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case TimestampType =>
+            to.addColumnValue(ColumnValue.timestampValue(null))
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            to.addColumnValue(ColumnValue.stringValue(null: String))
+        }
+      }
+
       def getResultSetSchema: TableSchema = {
         logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}")
         if (result.queryExecution.analyzed.output.size == 0) {
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
new file mode 100644
index 0000000000..ae08c640e6
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
@@ -0,0 +1,10 @@
+238val_238
+
+311val_311
+val_27
+val_165
+val_409
+255val_255
+278val_278
+98val_98
+val_484
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 78bffa2607..aedef6ce1f 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt
     val stmt = createStatement()
     stmt.execute("DROP TABLE IF EXISTS test")
     stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key int, val string)")
+    stmt.execute("CREATE TABLE test(key INT, val STRING)")
     stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
     stmt.execute("CACHE TABLE test_cached")
 
-    var rs = stmt.executeQuery("select count(*) from test")
+    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
     rs.next()
     assert(rs.getInt(1) === 5)
 
-    rs = stmt.executeQuery("select count(*) from test_cached")
+    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
     rs.next()
     assert(rs.getInt(1) === 4)
 
     stmt.close()
   }
 
+  test("SPARK-3004 regression: result set containing NULL") {
+    Thread.sleep(5 * 1000)
+    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
+    val stmt = createStatement()
+    stmt.execute("DROP TABLE IF EXISTS test_null")
+    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
+    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
+
+    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+    var count = 0
+    while (rs.next()) {
+      count += 1
+    }
+    assert(count === 5)
+
+    stmt.close()
+  }
+
   def getConnection: Connection = {
     val connectURI = s"jdbc:hive2://localhost:$PORT/"
     DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")