aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author: zhichao.li <zhichao.li@intel.com>, 2015-08-01 08:48:46 -0700
committer: Davies Liu <davies.liu@gmail.com>, 2015-08-01 08:48:46 -0700
commit: c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce (patch)
tree: c1b3ddbb2f8697743cd8e11aaeedcdf80d1adec7 /sql
parent: cf6c9ca32a89422e25007d333bc8714d9b0ae6d8 (diff)
download: spark-c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce.tar.gz
spark-c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce.tar.bz2
spark-c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce.zip
[SPARK-8263] [SQL] substr/substring should also support binary type
This is based on #7641, thanks to zhichao-li Closes #7641 Author: zhichao.li <zhichao.li@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7848 from davies/substr and squashes the following commits: 461b709 [Davies Liu] remove bytearry from tests b45377a [Davies Liu] Merge branch 'master' of github.com:apache/spark into substr 01d795e [zhichao.li] scala style 99aa130 [zhichao.li] add substring to dataframe 4f68bfe [zhichao.li] add binary type support for substring
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala 51
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala 15
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/functions.scala 11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 10
4 files changed, 81 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 3ce5d6a9c7..4d78c55497 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
+import java.util.Arrays
import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
@@ -679,6 +680,34 @@ case class StringSplit(str: Expression, pattern: Expression)
override def prettyName: String = "split"
}
+object Substring {
+  /**
+   * SQL-style substring for binary values.
+   *
+   * `pos` is 1-based: 1 (and also 0) selects the first byte, and a negative `pos` counts back
+   * from the end of `bytes` (-1 is the last byte). At most `len` bytes are returned. The
+   * requested window is clipped to the bounds of `bytes`, so a window that lies completely
+   * outside the array yields an empty array rather than an exception.
+   *
+   * @param bytes source byte array; it is only read
+   * @param pos   1-based start position, or a negative offset from the end
+   * @param len   maximum number of bytes to return; values <= 0 give an empty array
+   * @return the selected bytes, possibly empty
+   */
+  def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = {
+    val size = bytes.length
+    if (pos > size) {
+      Array.empty[Byte]
+    } else {
+      // The window is computed in Long. With Int arithmetic an extreme position such as
+      // Int.MinValue makes `size - start` wrap around, and the whole array was returned
+      // where the result must be empty.
+      val rawStart: Long =
+        if (pos > 0) pos - 1L
+        else if (pos < 0) size.toLong + pos
+        else 0L
+      // rawStart may still be negative here: the end is derived from the unclipped start so
+      // that a window hanging over the left edge keeps only the part overlapping the array.
+      val end = math.min(rawStart + len, size.toLong)
+      val start = math.max(rawStart, 0L)
+      if (start < end) {
+        // 0 <= start < end <= size, so both values fit in Int again.
+        Arrays.copyOfRange(bytes, start.toInt, end.toInt)
+      } else {
+        Array.empty[Byte]
+      }
+    }
+  }
+}
/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
@@ -690,18 +719,31 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
this(str, pos, Literal(Integer.MAX_VALUE))
}
- override def dataType: DataType = StringType
+ override def dataType: DataType = str.dataType // result has the same type as the first argument
- override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType) // (string or binary, pos, len)
override def children: Seq[Expression] = str :: pos :: len :: Nil // same order as the constructor arguments
override def nullSafeEval(string: Any, pos: Any, len: Any): Any = {
- string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+ str.dataType match { // choose the implementation from the declared type of the first child
+ case StringType => string.asInstanceOf[UTF8String]
+ .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+ case BinaryType => Substring.subStringBinarySQL(string.asInstanceOf[Array[Byte]],
+ pos.asInstanceOf[Int], len.asInstanceOf[Int])
+ } // no default case: inputTypes lists only StringType and BinaryType for this argument
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)")
+
+ val cls = classOf[Substring].getName // class name spliced into the generated call to the binary helper
+ defineCodeGen(ctx, ev, (string, pos, len) => { // named `string` so the field `str` is not shadowed below
+ str.dataType match { // the branch is chosen while the code string is being built
+ case StringType => s"$string.substringSQL($pos, $len)"
+ case BinaryType => s"$cls.subStringBinarySQL($string, $pos, $len)"
+ }
+ })
}
}
@@ -1161,4 +1203,3 @@ case class FormatNumber(x: Expression, d: Expression)
override def prettyName: String = "format_number"
}
-
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index ad87ab36fd..89c1e33420 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -186,6 +185,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(s.substr(0), "example", row)
checkEvaluation(s.substring(0, 2), "ex", row)
checkEvaluation(s.substring(0), "example", row)
+
+ val bytes = Array[Byte](1, 2, 3, 4)
+ checkEvaluation(Substring(bytes, 0, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, 1, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, 2, 2), Array[Byte](2, 3))
+ checkEvaluation(Substring(bytes, 3, 2), Array[Byte](3, 4))
+ checkEvaluation(Substring(bytes, 4, 2), Array[Byte](4))
+ checkEvaluation(Substring(bytes, 8, 2), Array[Byte]())
+ checkEvaluation(Substring(bytes, -1, 2), Array[Byte](4))
+ checkEvaluation(Substring(bytes, -2, 2), Array[Byte](3, 4))
+ checkEvaluation(Substring(bytes, -3, 2), Array[Byte](2, 3))
+ checkEvaluation(Substring(bytes, -4, 2), Array[Byte](1, 2))
+ checkEvaluation(Substring(bytes, -5, 2), Array[Byte](1))
+ checkEvaluation(Substring(bytes, -8, 2), Array[Byte]())
}
test("string substring_index function") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3c9421f5cd..babfe21879 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1726,6 +1726,17 @@ object functions {
def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
/**
+ * Returns the part of `str` that starts at `pos` and is at most `len` long.
+ *
+ * When `str` is of String type, `pos` and `len` are counted in characters and the result is
+ * a string. When `str` is of Binary type, they are counted in bytes and the result is the
+ * corresponding slice of the byte array.
+ *
+ * @param str a String or Binary column
+ * @param pos start position, 1-based (0 is treated like 1)
+ * @param len maximum length of the result
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def substring(str: Column, pos: Int, len: Int): Column =
+ Substring(str.expr, lit(pos).expr, lit(len).expr)
+
+ /**
* Computes the Levenshtein distance of the two given string columns.
* @group string_funcs
* @since 1.5.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 628da95298..f40233db0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -103,6 +103,16 @@ class StringFunctionsSuite extends QueryTest {
Row("AQIDBA==", bytes))
}
+ test("string / binary substring function") {
+ // scalastyle:off
+ // The string column below contains a non-ASCII character, which scalastyle does not allow
+ // in source code, so the style check is switched off for the body of this test.
+ val df = Seq(("1世3", Array[Byte](1, 2, 3, 4))).toDF("a", "b")
+ checkAnswer(df.select(substring($"a", 1, 2)), Row("1世")) // string column: counted in characters
+ checkAnswer(df.select(substring($"b", 2, 2)), Row(Array[Byte](2,3))) // binary column: counted in bytes
+ checkAnswer(df.selectExpr("substring(a, 1, 2)"), Row("1世")) // same call written as an expression string
+ // scalastyle:on
+ }
+
test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off