aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author Daoyuan Wang <daoyuan.wang@intel.com> 2016-05-23 23:29:15 -0700
committer Andrew Or <andrew@databricks.com> 2016-05-23 23:29:15 -0700
commit d642b273544bb77ef7f584326aa2d214649ac61b (patch)
tree e2bf63cd2c378d285165a7bf5f829dad93322efe /sql
parent de726b0d533158d3ca08841bd6976bcfa26ca79d (diff)
download spark-d642b273544bb77ef7f584326aa2d214649ac61b.tar.gz
spark-d642b273544bb77ef7f584326aa2d214649ac61b.tar.bz2
spark-d642b273544bb77ef7f584326aa2d214649ac61b.zip
[SPARK-15397][SQL] fix string udf locate as hive
## What changes were proposed in this pull request? In Hive, `locate("aa", "aaa", 0)` yields 0, `locate("aa", "aaa", 1)` yields 1, and `locate("aa", "aaa", 2)` yields 2, while in Spark, `locate("aa", "aaa", 0)` yields 1, `locate("aa", "aaa", 1)` yields 2, and `locate("aa", "aaa", 2)` yields 0. The discrepancy comes from a different interpretation of the third parameter of the `locate` UDF: it is the starting position, counted from 1, so passing 0 always returns 0. ## How was this patch tested? Tested with the modified `StringExpressionsSuite` and `StringFunctionsSuite`. Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #13186 from adrian-wang/locate.
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala 19
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala 16
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 10
3 files changed, 27 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 78e846d3f5..44ff7fda8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -494,7 +494,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
def this(substr: Expression, str: Expression) = {
- this(substr, str, Literal(0))
+ this(substr, str, Literal(1))
}
override def children: Seq[Expression] = substr :: str :: start :: Nil
@@ -516,9 +516,14 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
if (l == null) {
null
} else {
- l.asInstanceOf[UTF8String].indexOf(
- r.asInstanceOf[UTF8String],
- s.asInstanceOf[Int]) + 1
+ val sVal = s.asInstanceOf[Int]
+ if (sVal < 1) {
+ 0
+ } else {
+ l.asInstanceOf[UTF8String].indexOf(
+ r.asInstanceOf[UTF8String],
+ s.asInstanceOf[Int] - 1) + 1
+ }
}
}
}
@@ -537,8 +542,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
if (!${substrGen.isNull}) {
${strGen.code}
if (!${strGen.isNull}) {
- ${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
- ${startGen.value}) + 1;
+ if (${startGen.value} > 0) {
+ ${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
+ ${startGen.value} - 1) + 1;
+ }
} else {
${ev.isNull} = true;
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index c09c64fd6b..29bf15bf52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -508,16 +508,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val s2 = 'b.string.at(1)
val s3 = 'c.string.at(2)
val s4 = 'd.int.at(3)
- val row1 = create_row("aaads", "aa", "zz", 1)
- val row2 = create_row(null, "aa", "zz", 0)
- val row3 = create_row("aaads", null, "zz", 0)
- val row4 = create_row(null, null, null, 0)
+ val row1 = create_row("aaads", "aa", "zz", 2)
+ val row2 = create_row(null, "aa", "zz", 1)
+ val row3 = create_row("aaads", null, "zz", 1)
+ val row4 = create_row(null, null, null, 1)
checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
- checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1)
- checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(0)), 0, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 1, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 2, row1)
+ checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(3)), 0, row1)
checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
- checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
+ checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 2), 0, row1)
checkEvaluation(new StringLocate(s2, s1), 1, row1)
checkEvaluation(StringLocate(s2, s1, s4), 2, row1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index c7b95c2683..1de2d9b5ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -189,15 +189,15 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("string locate function") {
- val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
+ val df = Seq(("aaads", "aa", "zz", 2)).toDF("a", "b", "c", "d")
checkAnswer(
- df.select(locate("aa", $"a"), locate("aa", $"a", 1)),
- Row(1, 2))
+ df.select(locate("aa", $"a"), locate("aa", $"a", 2), locate("aa", $"a", 0)),
+ Row(1, 2, 0))
checkAnswer(
- df.selectExpr("locate(b, a)", "locate(b, a, d)"),
- Row(1, 2))
+ df.selectExpr("locate(b, a)", "locate(b, a, d)", "locate(b, a, 3)"),
+ Row(1, 2, 0))
}
test("string padding functions") {