about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala 24
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala 8
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala 61
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala 13
4 files changed, 82 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 2db954257b..f0bce388d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, _.numBytes() != 0)
+ buildCast[UTF8String](_, s => {
+ if (StringUtils.isTrueString(s)) {
+ true
+ } else if (StringUtils.isFalseString(s)) {
+ false
+ } else {
+ null
+ }
+ })
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>
@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
case StringType =>
- (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+ val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
+ (c, evPrim, evNull) =>
+ s"""
+ if ($stringUtils.isTrueString($c)) {
+ $evPrim = true;
+ } else if ($stringUtils.isFalseString($c)) {
+ $evPrim = false;
+ } else {
+ $evNull = true;
+ }
+ """
case TimestampType =>
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
case DateType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index 9ddfb3a0d3..c2eeb3c565 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
import java.util.regex.Pattern
+import org.apache.spark.unsafe.types.UTF8String
+
object StringUtils {
// replace the _ with .{1} exactly match 1 time of any character
@@ -44,4 +46,10 @@ object StringUtils {
v
}
}
+
+ private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
+ private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
+
+ def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
+ def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1ad70733ea..f4db4da764 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast from array") {
- val array = Literal.create(Seq("123", "abc", "", null),
+ val array = Literal.create(Seq("123", "true", "f", null),
ArrayType(StringType, containsNull = true))
- val array_notNull = Literal.create(Seq("123", "abc", ""),
+ val array_notNull = Literal.create(Seq("123", "true", "f"),
ArrayType(StringType, containsNull = false))
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
@@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false, null))
+ checkEvaluation(ret, Seq(null, true, false, null))
}
{
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
@@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false))
+ checkEvaluation(ret, Seq(null, true, false))
}
{
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
assert(ret.resolved === true)
- checkEvaluation(ret, Seq(true, true, false))
+ checkEvaluation(ret, Seq(null, true, false))
}
{
@@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("cast from map") {
val map = Literal.create(
- Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+ Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val map_notNull = Literal.create(
- Map("a" -> "123", "b" -> "abc", "c" -> ""),
+ Map("a" -> "123", "b" -> "true", "c" -> "f"),
MapType(StringType, StringType, valueContainsNull = false))
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
@@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
}
{
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
@@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
assert(ret.resolved === true)
- checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+ checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
}
{
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
@@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct = Literal.create(
InternalRow(
UTF8String.fromString("123"),
- UTF8String.fromString("abc"),
- UTF8String.fromString(""),
+ UTF8String.fromString("true"),
+ UTF8String.fromString("f"),
null),
StructType(Seq(
StructField("a", StringType, nullable = true),
@@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val struct_notNull = Literal.create(
InternalRow(
UTF8String.fromString("123"),
- UTF8String.fromString("abc"),
- UTF8String.fromString("")),
+ UTF8String.fromString("true"),
+ UTF8String.fromString("f")),
StructType(Seq(
StructField("a", StringType, nullable = false),
StructField("b", StringType, nullable = false),
@@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("c", BooleanType, nullable = true),
StructField("d", BooleanType, nullable = true))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false, null))
+ checkEvaluation(ret, InternalRow(null, true, false, null))
}
{
val ret = cast(struct, StructType(Seq(
@@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = true))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false))
+ checkEvaluation(ret, InternalRow(null, true, false))
}
{
val ret = cast(struct_notNull, StructType(Seq(
@@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StructField("b", BooleanType, nullable = true),
StructField("c", BooleanType, nullable = false))))
assert(ret.resolved === true)
- checkEvaluation(ret, InternalRow(true, true, false))
+ checkEvaluation(ret, InternalRow(null, true, false))
}
{
@@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
test("complex casting") {
val complex = Literal.create(
Row(
- Seq("123", "abc", ""),
- Map("a" ->"123", "b" -> "abc", "c" -> ""),
+ Seq("123", "true", "f"),
+ Map("a" ->"123", "b" -> "true", "c" -> "f"),
Row(0)),
StructType(Seq(
StructField("a",
@@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ret.resolved === true)
checkEvaluation(ret, Row(
Seq(123, null, null),
- Map("a" -> true, "b" -> true, "c" -> false),
+ Map("a" -> null, "b" -> true, "c" -> false),
Row(0L)))
}
- test("case between string and interval") {
+ test("cast between string and interval") {
import org.apache.spark.unsafe.types.CalendarInterval
checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
@@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
StringType),
"interval 1 years 3 months -3 days")
}
+
+ test("cast string to boolean") {
+ checkCast("t", true)
+ checkCast("true", true)
+ checkCast("tRUe", true)
+ checkCast("y", true)
+ checkCast("yes", true)
+ checkCast("1", true)
+
+ checkCast("f", false)
+ checkCast("false", false)
+ checkCast("FAlsE", false)
+ checkCast("n", false)
+ checkCast("no", false)
+ checkCast("0", false)
+
+ checkEvaluation(cast("abc", BooleanType), null)
+ checkEvaluation(cast("", BooleanType), null)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 13223c6158..8ffcef8566 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}
+ test("saveAsTable()/load() - partitioned table - boolean type") {
+ sqlContext.range(2)
+ .select('id, ('id % 2 === 0).as("b"))
+ .write.partitionBy("b").saveAsTable("t")
+
+ withTable("t") {
+ checkAnswer(
+ sqlContext.table("t").sort('id),
+ Row(0, true) :: Row(1, false) :: Nil
+ )
+ }
+ }
+
test("saveAsTable()/load() - partitioned table - Overwrite") {
partitionedTestDF.write
.format(dataSourceName)