about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-09-11 14:15:16 -0700
committerYin Huai <yhuai@databricks.com>2015-09-11 14:15:16 -0700
commitd5d647380f93f4773f9cb85ea6544892d409b5a1 (patch)
tree4d4fa02d42e7787e7e692695c1400b7c758800e2 /sql
parentc373866774c082885a50daaf7c83f3a14b0cd714 (diff)
downloadspark-d5d647380f93f4773f9cb85ea6544892d409b5a1.tar.gz
spark-d5d647380f93f4773f9cb85ea6544892d409b5a1.tar.bz2
spark-d5d647380f93f4773f9cb85ea6544892d409b5a1.zip
[SPARK-10442] [SQL] fix string to boolean cast
When we cast string to boolean in Hive, it returns `true` if the length of the string is > 0, and Spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', and throws an exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', and null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', and throws an exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', and null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throws an exception when trying to cast string to boolean. 6. mysql, oracle, and sqlserver don't have a boolean type. Whether we should change the cast behavior to match other SQL systems is not decided yet; this PR is a test to see how many compatibility tests would fail if we changed it. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8698 from cloud-fan/string2boolean.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala24
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala8
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala61
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala13
4 files changed, 82 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 2db954257b..f0bce388d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, _.numBytes() != 0)
+ buildCast[UTF8String](_, s => {
+ if (StringUtils.isTrueString(s)) {
+ true
+ } else if (StringUtils.isFalseString(s)) {
+ false
+ } else {
+ null
+ }
+ })
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>
@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
- (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+ val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
+ (c, evPrim, evNull) =>
+ s"""
+ if ($stringUtils.isTrueString($c)) {
+ $evPrim = true;
+ } else if ($stringUtils.isFalseString($c)) {
+ $evPrim = false;
+ } else {
+ $evNull = true;
+ }
+ """
case TimestampType =>
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
case DateType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index 9ddfb3a0d3..c2eeb3c565 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
import java.util.regex.Pattern
+import org.apache.spark.unsafe.types.UTF8String
+
object StringUtils {
// replace the _ with .{1} exactly match 1 time of any character
@@ -44,4 +46,10 @@ object StringUtils {
v
}
}
+
+ private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
+ private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
+
+ def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
+ def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1ad70733ea..f4db4da764 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from array") {
- val array = Literal.create(Seq("123", "abc", "", null),
+ val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
- val array_notNull = Literal.create(Seq("123", "abc", ""),
+ val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
@@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false, null))
+ checkEvaluation(ret, Seq(null, true, false, null))
}
{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
@@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false))
+ checkEvaluation(ret, Seq(null, true, false))
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false))
+ checkEvaluation(ret, Seq(null, true, false))
}
{
@@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from map") {
val map = Literal.create(
- Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+ Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
- Map("a" -> "123", "b" -> "abc", "c" -> ""),
+ Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
@@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
}
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
@@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
@@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
- UTF8String.fromString("abc"),
- UTF8String.fromString(""),
+ UTF8String.fromString("true"),
+ UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
@@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
- UTF8String.fromString("abc"),
- UTF8String.fromString("")),
+ UTF8String.fromString("true"),
+ UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
@@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("c", BooleanType, nullable = true),
StructField("d", BooleanType, nullable = true))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false, null))
+ checkEvaluation(ret, InternalRow(null, true, false, null))
}
{
val ret = cast(struct, StructType(Seq(
@@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false))
+ checkEvaluation(ret, InternalRow(null, true, false))
}
{
val ret = cast(struct_notNull, StructType(Seq(
@@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false))
+ checkEvaluation(ret, InternalRow(null, true, false))
}
{
@@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
Row(
- Seq("123", "abc", ""),
- Map("a" ->"123", "b" -> "abc", "c" -> ""),
+ Seq("123", "true", "f"),
+ Map("a" ->"123", "b" -> "true", "c" -> "f"),
Row(0)),
StructType(Seq(
StructField("a",
@@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved === true)
checkEvaluation(ret, Row(
Seq(123, null, null),
- Map("a" -> true, "b" -> true, "c" -> false),
+ Map("a" -> null, "b" -> true, "c" -> false),
Row(0L)))
}
- test("case between string and interval") {
+ test("cast between string and interval") {
import org.apache.spark.unsafe.types.CalendarInterval
checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
@@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StringType),
"interval 1 years 3 months -3 days")
}
+
+ test("cast string to boolean") {
+ checkCast("t", true)
+ checkCast("true", true)
+ checkCast("tRUe", true)
+ checkCast("y", true)
+ checkCast("yes", true)
+ checkCast("1", true)
+
+ checkCast("f", false)
+ checkCast("false", false)
+ checkCast("FAlsE", false)
+ checkCast("n", false)
+ checkCast("no", false)
+ checkCast("0", false)
+
+ checkEvaluation(cast("abc", BooleanType), null)
+ checkEvaluation(cast("", BooleanType), null)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 13223c6158..8ffcef8566 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}
+ test("saveAsTable()/load() - partitioned table - boolean type") {
+ sqlContext.range(2)
+ .select('id, ('id % 2 === 0).as("b"))
+ .write.partitionBy("b").saveAsTable("t")
+
+ withTable("t") {
+ checkAnswer(
+ sqlContext.table("t").sort('id),
+ Row(0, true) :: Row(1, false) :: Nil
+ )
+ }
+ }
+
test("saveAsTable()/load() - partitioned table - Overwrite") {
partitionedTestDF.write
.format(dataSourceName)