aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTarek Auel <tarek.auel@googlemail.com>2015-07-20 22:43:30 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-20 22:44:21 -0700
commit560b355ccd038ca044726c9c9fcffd14d02e6696 (patch)
treeaeb1e8031bc3ca57a410c3d8c5d5f27cdc82e6a7
parentc032b0bf92130dc4facb003f0deaeb1228aefded (diff)
downloadspark-560b355ccd038ca044726c9c9fcffd14d02e6696.tar.gz
spark-560b355ccd038ca044726c9c9fcffd14d02e6696.tar.bz2
spark-560b355ccd038ca044726c9c9fcffd14d02e6696.zip
[SPARK-9157] [SQL] codegen substring
https://issues.apache.org/jira/browse/SPARK-9157 Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7534 from tarekauel/SPARK-9157 and squashes the following commits: e65e3e9 [Tarek Auel] [SPARK-9157] indent fix 44e89f8 [Tarek Auel] [SPARK-9157] use EMPTY_UTF8 37d54c4 [Tarek Auel] Merge branch 'master' into SPARK-9157 60732ea [Tarek Auel] [SPARK-9157] created substringSQL in UTF8String 18c3576 [Tarek Auel] [SPARK-9157][SQL] remove slice pos 1a2e611 [Tarek Auel] [SPARK-9157][SQL] codegen substring
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala87
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java12
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java19
3 files changed, 75 insertions, 43 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5c1908d555..438215e8e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -640,7 +640,7 @@ case class StringSplit(str: Expression, pattern: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
- extends Expression with ImplicitCastInputTypes with CodegenFallback {
+ extends Expression with ImplicitCastInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
@@ -649,58 +649,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
- override def dataType: DataType = {
- if (!resolved) {
- throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved")
- }
- if (str.dataType == BinaryType) str.dataType else StringType
- }
+ override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
- @inline
- def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = {
- // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
- // negative indices for start positions. If a start index i is greater than 0, it
- // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
- // to the -ith element before the end of the sequence. If a start index i is 0, it
- // refers to the first element.
-
- val start = startPos match {
- case pos if pos > 0 => pos - 1
- case neg if neg < 0 => length() + neg
- case _ => 0
- }
-
- val end = sliceLen match {
- case max if max == Integer.MAX_VALUE => max
- case x => start + x
+ override def eval(input: InternalRow): Any = {
+ val stringEval = str.eval(input)
+ if (stringEval != null) {
+ val posEval = pos.eval(input)
+ if (posEval != null) {
+ val lenEval = len.eval(input)
+ if (lenEval != null) {
+ stringEval.asInstanceOf[UTF8String]
+ .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
+ } else {
+ null
+ }
+ } else {
+ null
+ }
+ } else {
+ null
}
-
- (start, end)
}
- override def eval(input: InternalRow): Any = {
- val string = str.eval(input)
- val po = pos.eval(input)
- val ln = len.eval(input)
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val strGen = str.gen(ctx)
+ val posGen = pos.gen(ctx)
+ val lenGen = len.gen(ctx)
- if ((string == null) || (po == null) || (ln == null)) {
- null
- } else {
- val start = po.asInstanceOf[Int]
- val length = ln.asInstanceOf[Int]
- string match {
- case ba: Array[Byte] =>
- val (st, end) = slicePos(start, length, () => ba.length)
- ba.slice(st, end)
- case s: UTF8String =>
- val (st, end) = slicePos(start, length, () => s.numChars())
- s.substring(st, end)
+ val start = ctx.freshName("start")
+ val end = ctx.freshName("end")
+
+ s"""
+ ${strGen.code}
+ boolean ${ev.isNull} = ${strGen.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${posGen.code}
+ if (!${posGen.isNull}) {
+ ${lenGen.code}
+ if (!${lenGen.isNull}) {
+ ${ev.primitive} = ${strGen.primitive}
+ .substringSQL(${posGen.primitive}, ${lenGen.primitive});
+ } else {
+ ${ev.isNull} = true;
+ }
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
+ """
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index ed354f7f87..946d355f1f 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -165,6 +165,18 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
return fromBytes(bytes);
}
+ public UTF8String substringSQL(int pos, int length) {
+ // Information regarding the pos calculation:
+ // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and
+ // negative indices for start positions. If a start index i is greater than 0, it
+ // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
+ // to the -ith element before the end of the sequence. If a start index i is 0, it
+ // refers to the first element.
+ int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0);
+ int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length;
+ return substring(start, end);
+ }
+
/**
* Returns whether this contains `substring` or not.
*/
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 1f5572c509..e2a5628ff4 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -273,6 +273,25 @@ public class UTF8StringSuite {
}
@Test
+ public void substringSQL() {
+ UTF8String e = fromString("example");
+ assertEquals(e.substringSQL(0, 2), fromString("ex"));
+ assertEquals(e.substringSQL(1, 2), fromString("ex"));
+ assertEquals(e.substringSQL(0, 7), fromString("example"));
+ assertEquals(e.substringSQL(1, 2), fromString("ex"));
+ assertEquals(e.substringSQL(0, 100), fromString("example"));
+ assertEquals(e.substringSQL(1, 100), fromString("example"));
+ assertEquals(e.substringSQL(2, 2), fromString("xa"));
+ assertEquals(e.substringSQL(1, 6), fromString("exampl"));
+ assertEquals(e.substringSQL(2, 100), fromString("xample"));
+ assertEquals(e.substringSQL(0, 0), fromString(""));
+ assertEquals(e.substringSQL(100, 4), EMPTY_UTF8);
+ assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example"));
+ assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example"));
+ assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample"));
+ }
+
+ @Test
public void split() {
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));