aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala28
1 files changed, 27 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3c23f2ecfb..b60d318534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -409,13 +409,14 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
* in given string after position pos.
*/
case class StringLocate(substr: Expression, str: Expression, start: Expression)
- extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
+ extends TernaryExpression with ImplicitCastInputTypes {
def this(substr: Expression, str: Expression) = {
this(substr, str, Literal(0))
}
override def children: Seq[Expression] = substr :: str :: start :: Nil
+ override def nullable: Boolean = substr.nullable || str.nullable
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -441,6 +442,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
}
}
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val substrGen = substr.gen(ctx)
+ val strGen = str.gen(ctx)
+ val startGen = start.gen(ctx)
+ s"""
+ int ${ev.primitive} = 0;
+ boolean ${ev.isNull} = false;
+ ${startGen.code}
+ if (!${startGen.isNull}) {
+ ${substrGen.code}
+ if (!${substrGen.isNull}) {
+ ${strGen.code}
+ if (!${strGen.isNull}) {
+ ${ev.primitive} = ${strGen.primitive}.indexOf(${substrGen.primitive},
+ ${startGen.primitive}) + 1;
+ } else {
+ ${ev.isNull} = true;
+ }
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
+ }
+
override def prettyName: String = "locate"
}