From d5d647380f93f4773f9cb85ea6544892d409b5a1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 Sep 2015 14:15:16 -0700 Subject: [SPARK-10442] [SQL] fix string to boolean cast When we cast string to boolean in hive, it returns `true` if the length of string is > 0, and spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throw exception when try to cast string to boolean. 6. mysql, oracle, sqlserver don't have boolean type Whether we should change the cast behavior according to other SQL system or not is not decided yet, this PR is a test to see if we changed, how many compatibility tests will fail. Author: Wenchen Fan Closes #8698 from cloud-fan/string2boolean. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 +++++++-- .../spark/sql/catalyst/util/StringUtils.scala | 8 +++ .../spark/sql/catalyst/expressions/CastSuite.scala | 61 ++++++++++++++-------- .../spark/sql/sources/hadoopFsRelationSuites.scala | 13 +++++ 4 files changed, 82 insertions(+), 24 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2db954257b..f0bce388d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d3..c2eeb3c565 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733ea..f4db4da764 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" ->"123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 13223c6158..8ffcef8566 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) -- cgit v1.2.3