diff options
author | Davies Liu <davies@databricks.com> | 2015-07-06 13:31:31 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-06 13:31:31 -0700 |
commit | 37e4d92142a6309e2df7d36883e0c7892c3d792d (patch) | |
tree | 20a23e5a565ebe20919cb20be03e9335d4088dc8 | |
parent | 57c72fcce75907c08a1ae53a0d85447176fc3c69 (diff) | |
download | spark-37e4d92142a6309e2df7d36883e0c7892c3d792d.tar.gz spark-37e4d92142a6309e2df7d36883e0c7892c3d792d.tar.bz2 spark-37e4d92142a6309e2df7d36883e0c7892c3d792d.zip |
[SPARK-8784] [SQL] Add Python API for hex and unhex
Add Python API for hex/unhex, also cleanup Hex/Unhex
Author: Davies Liu <davies@databricks.com>
Closes #7223 from davies/hex and squashes the following commits:
6f1249d [Davies Liu] no explicit rule to cast string into binary
711a6ed [Davies Liu] fix test
f9fe5a3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex
f032fbb [Davies Liu] Merge branch 'hex' of github.com:davies/spark into hex
49e325f [Davies Liu] Merge branch 'master' of github.com:apache/spark into hex
b31fc9a [Davies Liu] Update math.scala
25156b7 [Davies Liu] address comments and fix test
c3af78c [Davies Liu] address commments
1a24082 [Davies Liu] Add Python API for hex and unhex
5 files changed, 93 insertions, 47 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 49dd0332af..dca39fa833 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -397,6 +397,34 @@ def randn(seed=None): @ignore_unicode_prefix @since(1.5) +def hex(col): + """Computes hex value of the given column, which could be StringType, + BinaryType, IntegerType or LongType. + + >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + [Row(hex(a)=u'414243', hex(b)=u'3')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def unhex(col): + """Inverse of hex. Interprets each pair of characters as a hexadecimal number + and converts to the byte representation of number. + + >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + [Row(unhex(a)=bytearray(b'ABC'))] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.unhex(_to_java_column(col)) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) def sha1(col): """Returns the hex string result of SHA-1. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 92a50e7092..fef2763530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,7 +168,7 @@ object FunctionRegistry { expression[Substring]("substring"), expression[UnBase64]("unbase64"), expression[Upper]("ucase"), - expression[UnHex]("unhex"), + expression[Unhex]("unhex"), expression[Upper]("upper"), // datetime functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 45b7e4d340..9250045398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -298,6 +298,21 @@ case class Bin(child: Expression) } } +object Hex { + val hexDigits = Array[Char]( + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + ).map(_.toByte) + + // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15 + val unhexDigits = { + val array = Array.fill[Byte](128)(-1) + (0 to 9).foreach(i => array('0' + i) = i.toByte) + (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) + (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) + array + } +} + /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. * Otherwise if the number is a STRING, it converts each character into its hex representation @@ -307,7 +322,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, StringType, BinaryType)) + Seq(TypeCollection(LongType, BinaryType, StringType)) override def dataType: DataType = StringType @@ -319,30 +334,18 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes child.dataType match { case LongType => hex(num.asInstanceOf[Long]) case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } } } - /** - * Converts every character in s to two hex digits. - */ - private def hex(str: UTF8String): UTF8String = { - hex(str.getBytes) - } - - private def hex(bytes: Array[Byte]): UTF8String = { - doHex(bytes, bytes.length) - } - - private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + private[this] def hex(bytes: Array[Byte]): UTF8String = { + val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 while (i < length) { - value(i * 2) = Character.toUpperCase(Character.forDigit( - (bytes(i) & 0xF0) >>> 4, 16)).toByte - value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( - bytes(i) & 0x0F, 16)).toByte + value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4) + value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F) i += 1 } UTF8String.fromBytes(value) @@ -355,24 +358,23 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes var len = 0 do { len += 1 - value(value.length - len) = - Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte + value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt) numBuf >>>= 4 } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } } - /** * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTypes { // TODO: Create code-gen version. override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + override def nullable: Boolean = true override def dataType: DataType = BinaryType override def eval(input: InternalRow): Any = { @@ -384,26 +386,31 @@ case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTyp } } - private val unhexDigits = { - val array = Array.fill[Byte](128)(-1) - (0 to 9).foreach(i => array('0' + i) = i.toByte) - (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte) - (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) - array - } - - private def unhex(inputBytes: Array[Byte]): Array[Byte] = { - var bytes = inputBytes + private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + val out = new Array[Byte]((bytes.length + 1) >> 1) + var i = 0 if ((bytes.length & 0x01) != 0) { - bytes = '0'.toByte +: bytes + // padding with '0' + if (bytes(0) < 0) { + return null + } + val v = Hex.unhexDigits(bytes(0)) + if (v == -1) { + return null + } + out(0) = v + i += 1 } - val out = new Array[Byte](bytes.length >> 1) // two characters form the hex value. - var i = 0 while (i < bytes.length) { - val first = unhexDigits(bytes(i)) - val second = unhexDigits(bytes(i + 1)) - if (first == -1 || second == -1) { return null} + if (bytes(i) < 0 || bytes(i + 1) < 0) { + return null + } + val first = Hex.unhexDigits(bytes(i)) + val second = Hex.unhexDigits(bytes(i + 1)) + if (first == -1 || second == -1) { + return null + } out(i / 2) = (((first << 4) | second) & 0xFF).toByte i += 2 } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 03d8400cf3..7ca9e30b2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -21,8 +21,7 @@ import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{DataType, LongType} -import org.apache.spark.sql.types.{IntegerType, DoubleType} +import org.apache.spark.sql.types._ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -271,20 +270,32 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("hex") { + checkEvaluation(Hex(Literal.create(null, LongType)), null) + checkEvaluation(Hex(Literal(28L)), "1C") + checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") - checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") // scalastyle:off // Turn off scala style for non-ascii chars - checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + checkEvaluation(Hex(Literal("三重的".getBytes("UTF8"))), "E4B889E9878DE79A84") // scalastyle:on } test("unhex") { - checkEvaluation(UnHex(Literal("737472696E67")), "string".getBytes) - checkEvaluation(UnHex(Literal("")), new Array[Byte](0)) - checkEvaluation(UnHex(Literal("0")), Array[Byte](0)) + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + checkEvaluation(Unhex(Literal("GG")), null) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes("UTF-8")) + checkEvaluation(Unhex(Literal("三重的")), null) + + // scalastyle:on } test("hypot") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f80291776f..4da9ffc495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1095,7 +1095,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = UnHex(column.expr) + def unhex(column: Column): Column = Unhex(column.expr) /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number |