aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWilliam Benton <willb@redhat.com>2014-07-15 14:11:57 -0700
committerMichael Armbrust <michael@databricks.com>2014-07-15 14:11:57 -0700
commit61de65bc69f9a5fc396b76713193c6415436d452 (patch)
treec22d4e12c4de2ffd72a7c294e2316d39aaca3f36 /sql
parent8af46d58464b96471825ce376c3e11c8b1108c0e (diff)
downloadspark-61de65bc69f9a5fc396b76713193c6415436d452.tar.gz
spark-61de65bc69f9a5fc396b76713193c6415436d452.tar.bz2
spark-61de65bc69f9a5fc396b76713193c6415436d452.zip
SPARK-2407: Added internal implementation of SQL SUBSTR()
This replaces the Hive UDF for SUBSTR(ING) with an implementation in Catalyst and adds tests to verify correct operation. Author: William Benton <willb@redhat.com> Closes #1359 from willb/internalSqlSubstring and squashes the following commits: ccedc47 [William Benton] Fixed too-long line. a30a037 [William Benton] replace view bounds with implicit parameters ec35c80 [William Benton] Adds fixes from review: 4f3bfdb [William Benton] Added internal implementation of SQL SUBSTR()
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala77
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala49
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala5
3 files changed, 128 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index b3850533c3..4bd7bf5a0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
-import org.apache.spark.sql.catalyst.types.DataType
-import org.apache.spark.sql.catalyst.types.StringType
-import org.apache.spark.sql.catalyst.types.BooleanType
+import scala.collection.IndexedSeqOptimized
+
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType}
trait StringRegexExpression {
self: BinaryExpression =>
@@ -205,3 +207,72 @@ case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with StringComparison {
def compare(l: String, r: String) = l.endsWith(r)
}
+
+/**
+ * A function that takes a substring of its first argument starting at a given position.
+ * Defined for String and Binary types.
+ */
+case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
+
+ type EvaluatedType = Any
+
+ def nullable: Boolean = true
+ def dataType: DataType = {
+ if (!resolved) {
+ throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
+ }
+ if (str.dataType == BinaryType) str.dataType else StringType
+ }
+
+ def references = children.flatMap(_.references).toSet
+
+ override def children = str :: pos :: len :: Nil
+
+ @inline
+ def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int)
+ (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = {
+ val len = str.length
+ // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
+ // negative indices for start positions. If a start index i is greater than 0, it
+ // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
+ // to the -ith element before the end of the sequence. If a start index i is 0, it
+ // refers to the first element.
+
+ val start = startPos match {
+ case pos if pos > 0 => pos - 1
+ case neg if neg < 0 => len + neg
+ case _ => 0
+ }
+
+ val end = sliceLen match {
+ case max if max == Integer.MAX_VALUE => max
+ case x => start + x
+ }
+
+ str.slice(start, end)
+ }
+
+ override def eval(input: Row): Any = {
+ val string = str.eval(input)
+
+ val po = pos.eval(input)
+ val ln = len.eval(input)
+
+ if ((string == null) || (po == null) || (ln == null)) {
+ null
+ } else {
+ val start = po.asInstanceOf[Int]
+ val length = ln.asInstanceOf[Int]
+
+ string match {
+ case ba: Array[Byte] => slice(ba, start, length)
+ case other => slice(other.toString, start, length)
+ }
+ }
+ }
+
+ override def toString = len match {
+ case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)"
+ case _ => s"SUBSTR($str, $pos, $len)"
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 84d7281477..f1d7aedcc2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -466,5 +466,54 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(c1 === c2, false, row)
checkEvaluation(c1 !== c2, true, row)
}
+
+ test("Substring") {
+ val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte)))
+
+ val s = 'a.string.at(0)
+
+ // substring from zero position with less-than-full length
+ checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row)
+ checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row)
+
+ // substring from zero position with full length
+ checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row)
+ checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row)
+
+ // substring from zero position with greater-than-full length
+ checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row)
+ checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row)
+
+ // substring from nonzero position with less-than-full length
+ checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row)
+
+ // substring from nonzero position with full length
+ checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row)
+
+ // substring from nonzero position with greater-than-full length
+ checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row)
+
+ // zero-length substring (within string bounds)
+ checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row)
+
+ // zero-length substring (beyond string bounds)
+ checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row)
+
+ // substring(null, _, _) -> null
+ checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null)))
+
+ // substring(_, null, _) -> null
+ checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row)
+
+ // substring(_, _, null) -> null
+ checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row)
+
+ // 2-arg substring from zero position
+ checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row)
+ checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row)
+
+ // 2-arg substring from nonzero position
+ checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row)
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 56aa27a208..300e249f5b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -860,6 +860,7 @@ private[hive] object HiveQl {
val BETWEEN = "(?i)BETWEEN".r
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
+ val SUBSTR = "(?i)SUBSTR(?:ING)?".r
protected def nodeToExpr(node: Node): Expression = node match {
/* Attribute References */
@@ -984,6 +985,10 @@ private[hive] object HiveQl {
/* Other functions */
case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+ case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
+ Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType))
+ case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
+ Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
/* UDFs - Must be last otherwise will preempt built in functions */
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>