aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-07-03 23:45:21 -0700
committerReynold Xin <rxin@databricks.com>2015-07-03 23:45:21 -0700
commitf35b0c3436898f22860d2c6c1d12f3a661005201 (patch)
tree2e2437cf66f651c62b521557c8fea4957f999c70 /sql
parentf32487b7ca86f768336a7c9b173f7c610fcde86f (diff)
downloadspark-f35b0c3436898f22860d2c6c1d12f3a661005201.tar.gz
spark-f35b0c3436898f22860d2c6c1d12f3a661005201.tar.bz2
spark-f35b0c3436898f22860d2c6c1d12f3a661005201.zip
[SPARK-8238][SPARK-8239][SPARK-8242][SPARK-8243][SPARK-8268][SQL]Add ascii/base64/unbase64/encode/decode functions
Add `ascii`,`base64`,`unbase64`,`encode` and `decode` expressions. Author: Cheng Hao <hao.cheng@intel.com> Closes #6843 from chenghao-intel/str_funcs2 and squashes the following commits: 78dee7d [Cheng Hao] base 64 -> base64 9d6f9f4 [Cheng Hao] remove the toString method for expressions ed5c19c [Cheng Hao] update code as comments 96170fc [Cheng Hao] scalastyle issues e2df768 [Cheng Hao] remove the unused import 491ce7b [Cheng Hao] add ascii/base64/unbase64/encode/decode functions
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala117
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala60
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala93
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala38
5 files changed, 308 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a1299aed55..e249b58927 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -156,11 +156,16 @@ object FunctionRegistry {
expression[Sum]("sum"),
// string functions
+ expression[Ascii]("ascii"),
+ expression[Base64]("base64"),
+ expression[Encode]("encode"),
+ expression[Decode]("decode"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[UnHex]("unhex"),
expression[Upper]("upper"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 57918b32f8..154ac3508c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -298,3 +298,120 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
override def prettyName: String = "length"
}
+
+/**
+ * Returns the numeric value of the first character of str.
+ */
+case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def dataType: DataType = IntegerType
+ override def inputTypes: Seq[DataType] = Seq(StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val string = child.eval(input)
+ if (string == null) {
+ null
+ } else {
+ val bytes = string.asInstanceOf[UTF8String].getBytes
+ if (bytes.length > 0) {
+ bytes(0).asInstanceOf[Int]
+ } else {
+ 0
+ }
+ }
+ }
+}
+
+/**
+ * Converts the argument from binary to a base 64 string.
+ */
+case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(BinaryType)
+
+ override def eval(input: InternalRow): Any = {
+ val bytes = child.eval(input)
+ if (bytes == null) {
+ null
+ } else {
+ UTF8String.fromBytes(
+ org.apache.commons.codec.binary.Base64.encodeBase64(
+ bytes.asInstanceOf[Array[Byte]]))
+ }
+ }
+}
+
+/**
+ * Converts the argument from a base 64 string to BINARY.
+ */
+case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[DataType] = Seq(StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val string = child.eval(input)
+ if (string == null) {
+ null
+ } else {
+ org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString)
+ }
+ }
+}
+
+/**
+ * Decodes the first argument into a String using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null. (As of Hive 0.12.0.).
+ */
+case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
+ override def children: Seq[Expression] = bin :: charset :: Nil
+ override def foldable: Boolean = bin.foldable && charset.foldable
+ override def nullable: Boolean = bin.nullable || charset.nullable
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val l = bin.eval(input)
+ if (l == null) {
+ null
+ } else {
+ val r = charset.eval(input)
+ if (r == null) {
+ null
+ } else {
+ val fromCharset = r.asInstanceOf[UTF8String].toString
+ UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset))
+ }
+ }
+ }
+}
+
+/**
+ * Encodes the first argument into a BINARY using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null. (As of Hive 0.12.0.)
+*/
+case class Encode(value: Expression, charset: Expression)
+ extends Expression with ExpectsInputTypes {
+ override def children: Seq[Expression] = value :: charset :: Nil
+ override def foldable: Boolean = value.foldable && charset.foldable
+ override def nullable: Boolean = value.nullable || charset.nullable
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
+
+ override def eval(input: InternalRow): Any = {
+ val l = value.eval(input)
+ if (l == null) {
+ null
+ } else {
+ val r = charset.eval(input)
+ if (r == null) {
+ null
+ } else {
+ val toCharset = r.asInstanceOf[UTF8String].toString
+ l.asInstanceOf[UTF8String].toString.getBytes(toCharset)
+ }
+ }
+ }
+}
+
+
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 5dbb1d562c..468df20442 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -217,11 +217,61 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("length for string") {
- val regEx = 'a.string.at(0)
+ val a = 'a.string.at(0)
checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
- checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
- checkEvaluation(StringLength(regEx), 0, create_row(""))
- checkEvaluation(StringLength(regEx), null, create_row(null))
+ checkEvaluation(StringLength(a), 5, create_row("abdef"))
+ checkEvaluation(StringLength(a), 0, create_row(""))
+ checkEvaluation(StringLength(a), null, create_row(null))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}
+
+ test("ascii for string") {
+ val a = 'a.string.at(0)
+ checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
+ checkEvaluation(Ascii(a), 97, create_row("abdef"))
+ checkEvaluation(Ascii(a), 0, create_row(""))
+ checkEvaluation(Ascii(a), null, create_row(null))
+ checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef"))
+ }
+
+ test("base64/unbase64 for string") {
+ val a = 'a.string.at(0)
+ val b = 'b.binary.at(0)
+ val bytes = Array[Byte](1, 2, 3, 4)
+
+ checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef"))
+ checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef"))
+ checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
+ checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef"))
+ checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))
+
+ checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
+ checkEvaluation(Base64(b), "", create_row(Array[Byte]()))
+ checkEvaluation(Base64(b), null, create_row(null))
+ checkEvaluation(Base64(Literal.create(null, StringType)), null, create_row("abdef"))
+
+ checkEvaluation(UnBase64(a), null, create_row(null))
+ checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef"))
+ }
+
+ test("encode/decode for string") {
+ val a = 'a.string.at(0)
+ val b = 'b.binary.at(0)
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ checkEvaluation(
+ Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
+ checkEvaluation(
+ Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
+ checkEvaluation(
+ Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))
+ // scalastyle:on
+ checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null))
+ checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null)
+ checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row(""))
+
+ checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
+ checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
+ checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 25e37ff67a..b63c6ee8ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1581,6 +1581,7 @@ object functions {
/**
* Computes the length of a given string value
+ *
* @group string_funcs
* @since 1.5.0
*/
@@ -1588,11 +1589,103 @@ object functions {
/**
* Computes the length of a given string column
+ *
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))
+ /**
+ * Computes the numeric value of the first character of the specified string value.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def ascii(e: Column): Column = Ascii(e.expr)
+
+ /**
+ * Computes the numeric value of the first character of the specified string column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def ascii(columnName: String): Column = ascii(Column(columnName))
+
+ /**
+ * Computes the specified value from binary to a base64 string.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def base64(e: Column): Column = Base64(e.expr)
+
+ /**
+ * Computes the specified column from binary to a base64 string.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def base64(columnName: String): Column = base64(Column(columnName))
+
+ /**
+ * Computes the specified value from a base64 string to binary.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def unbase64(e: Column): Column = UnBase64(e.expr)
+
+ /**
+ * Computes the specified column from a base64 string to binary.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def unbase64(columnName: String): Column = unbase64(Column(columnName))
+
+ /**
+ * Computes the first argument into a binary from a string using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)
+
+ /**
+ * Computes the first argument into a binary from a string using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def encode(columnName: String, charsetColumnName: String): Column =
+ encode(Column(columnName), Column(charsetColumnName))
+
+ /**
+ * Computes the first argument into a string from a binary using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)
+
+ /**
+ * Computes the first argument into a string from a binary using the provided character set
+ * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
+ * If either argument is null, the result will also be null.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def decode(columnName: String, charsetColumnName: String): Column =
+ decode(Column(columnName), Column(charsetColumnName))
+
+
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0d43aca877..bd9fa400e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -225,4 +225,42 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}
+
+ test("string ascii function") {
+ val df = Seq(("abc", "")).toDF("a", "b")
+ checkAnswer(
+ df.select(ascii($"a"), ascii("b")),
+ Row(97, 0))
+
+ checkAnswer(
+ df.selectExpr("ascii(a)", "ascii(b)"),
+ Row(97, 0))
+ }
+
+ test("string base64/unbase64 function") {
+ val bytes = Array[Byte](1, 2, 3, 4)
+ val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
+ checkAnswer(
+ df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
+ Row("AQIDBA==", "AQIDBA==", bytes, bytes))
+
+ checkAnswer(
+ df.selectExpr("base64(a)", "unbase64(b)"),
+ Row("AQIDBA==", bytes))
+ }
+
+ test("string encode/decode function") {
+ val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
+ checkAnswer(
+ df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
+ Row(bytes, bytes, "大千世界", "大千世界"))
+
+ checkAnswer(
+ df.selectExpr("encode(a, b)", "decode(c, b)"),
+ Row(bytes, "大千世界"))
+ // scalastyle:on
+ }
}