aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-06-25 22:07:37 -0700
committerDavies Liu <davies@databricks.com>2015-06-25 22:07:37 -0700
commit47c874babe7779c7a2f32e0b891503ef6bebcab0 (patch)
tree40a848f19d98db6c01cd2dccfe716a36f5c491fd /sql
parentc392a9efabcb1ec2a2c53f001ecdae33c245ba35 (diff)
downloadspark-47c874babe7779c7a2f32e0b891503ef6bebcab0.tar.gz
spark-47c874babe7779c7a2f32e0b891503ef6bebcab0.tar.bz2
spark-47c874babe7779c7a2f32e0b891503ef6bebcab0.zip
[SPARK-8237] [SQL] Add misc function sha2
JIRA: https://issues.apache.org/jira/browse/SPARK-8237 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6934 from viirya/expr_sha2 and squashes the following commits: 35e0bb3 [Liang-Chi Hsieh] For comments. 68b5284 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 8573aff [Liang-Chi Hsieh] Remove unnecessary Product. ee61e06 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 59e41aa [Liang-Chi Hsieh] Add misc function: sha2.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala98
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala17
5 files changed, 146 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5fb3369f85..457948a800 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -135,6 +135,7 @@ object FunctionRegistry {
// misc functions
expression[Md5]("md5"),
+ expression[Sha2]("sha2"),
// aggregate functions
expression[Average]("avg"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4bee8cb728..e80706fc65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import java.security.MessageDigest
+import java.security.NoSuchAlgorithmException
+
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, StringType, DataType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -44,7 +47,96 @@ case class Md5(child: Expression)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
- "org.apache.spark.unsafe.types.UTF8String.fromString" +
- s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+ s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+ }
+}
+
+/**
+ * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
+ * and returns it as a hex string. The first argument is the string or binary to be hashed. The
+ * second argument indicates the desired bit length of the result, which must have a value of 224,
+ * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If
+ * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or
+ * the hash length is not one of the permitted values, the return value is NULL.
+ *
+ * Every supported bit length, including 224, yields the hex encoding of the digest; the
+ * interpreted path ([[eval]]) and the generated-code path ([[genCode]]) must stay in agreement.
+ */
+case class Sha2(left: Expression, right: Expression)
+  extends BinaryExpression with Serializable with ExpectsInputTypes {
+
+  override def dataType: DataType = StringType
+
+  override def toString: String = s"SHA2($left, $right)"
+
+  // left: the bytes to hash; right: the requested digest length in bits.
+  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+
+  /**
+   * Interpreted evaluation. Returns the hex digest as a UTF8String, or null when either child
+   * evaluates to null, the bit length is not permitted, or SHA-224 is unavailable on this JVM.
+   */
+  override def eval(input: InternalRow): Any = {
+    val evalE1 = left.eval(input)
+    if (evalE1 == null) {
+      null
+    } else {
+      // The right child is only evaluated once the left child is known to be non-null.
+      val evalE2 = right.eval(input)
+      if (evalE2 == null) {
+        null
+      } else {
+        val bitLength = evalE2.asInstanceOf[Int]
+        // Named `bytes` so that it does not shadow the `input` row parameter.
+        val bytes = evalE1.asInstanceOf[Array[Byte]]
+        bitLength match {
+          case 224 =>
+            // DigestUtils doesn't support SHA-224 now
+            try {
+              val md = MessageDigest.getInstance("SHA-224")
+              md.update(bytes)
+              // Hex-encode the digest like the other branches do; wrapping the raw digest
+              // bytes directly would produce a non-hex (and generally invalid UTF-8) string.
+              UTF8String.fromString(
+                org.apache.commons.codec.binary.Hex.encodeHexString(md.digest()))
+            } catch {
+              // SHA-224 is not supported on the system, return null
+              case noa: NoSuchAlgorithmException => null
+            }
+          case 256 | 0 =>
+            UTF8String.fromString(DigestUtils.sha256Hex(bytes))
+          case 384 =>
+            UTF8String.fromString(DigestUtils.sha384Hex(bytes))
+          case 512 =>
+            UTF8String.fromString(DigestUtils.sha512Hex(bytes))
+          case _ => null
+        }
+      }
+    }
+  }
+
+  /**
+   * Generated-code evaluation. Emits Java that mirrors [[eval]]: the result is flagged null when
+   * either child is null, the bit length is not permitted, or SHA-224 is unavailable.
+   */
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval1 = left.gen(ctx)
+    val eval2 = right.gen(ctx)
+    // Fully qualified names, since the generated Java source does not share this file's imports.
+    val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
+    val hex = "org.apache.commons.codec.binary.Hex"
+
+    s"""
+      ${eval1.code}
+      boolean ${ev.isNull} = ${eval1.isNull};
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${eval2.code}
+        if (!${eval2.isNull}) {
+          if (${eval2.primitive} == 224) {
+            try {
+              java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+              md.update(${eval1.primitive});
+              ${ev.primitive} =
+                ${ctx.stringType}.fromString(${hex}.encodeHexString(md.digest()));
+            } catch (java.security.NoSuchAlgorithmException e) {
+              ${ev.isNull} = true;
+            }
+          } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 384) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
+          } else if (${eval2.primitive} == 512) {
+            ${ev.primitive} =
+              ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
+          } else {
+            ${ev.isNull} = true;
+          }
+        } else {
+          ${ev.isNull} = true;
+        }
+      }
+    """
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 48b84130b4..38482c54c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.commons.codec.digest.DigestUtils
+
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{StringType, BinaryType}
+import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
}
+  test("sha2") {
+    // Supported bit lengths: the result must match commons-codec's hex digest.
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
+    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
+      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
+    // A bit length of 0 is an alias for 256.
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(0)), DigestUtils.sha256Hex("ABC"))
+    // Unsupported bit length with a non-null input, so the fallback branch is actually reached
+    // instead of short-circuiting on the NULL input.
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(1024)), null)
+    // NULL in either argument yields NULL, whatever the other argument is.
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
+    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
+  }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 38d9085a50..355ce0e342 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1414,6 +1414,26 @@ object functions {
*/
def md5(columnName: String): Column = md5(Column(columnName))
+  /**
+   * Computes a SHA-2 family digest of the given column and returns it as a hex string.
+   *
+   * @param e the column whose value is hashed
+   * @param numBits requested digest length in bits; must be one of 0, 224, 256, 384 or 512,
+   *                otherwise an IllegalArgumentException is thrown
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(e: Column, numBits: Int): Column = {
+    val permittedBits = Seq(0, 224, 256, 384, 512)
+    require(permittedBits.contains(numBits),
+      s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)")
+    val bitsExpr = lit(numBits).expr
+    Sha2(e.expr, bitsExpr)
+  }
+
+  /**
+   * Computes a SHA-2 family digest of the named column and returns it as a hex string.
+   * Resolves the name to a Column and delegates to the Column-based overload.
+   *
+   * @param columnName name of the column whose value is hashed
+   * @param numBits requested digest length in bits, passed through unchanged
+   *
+   * @group misc_funcs
+   * @since 1.5.0
+   */
+  def sha2(columnName: String, numBits: Int): Column = {
+    val target = Column(columnName)
+    sha2(target, numBits)
+  }
+
//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8b53b384a2..8baed57a7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest {
Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
}
+  test("misc sha2 function") {
+    val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
+
+    // Expected SHA-256 hex digests for columns "a" and "b"; shared by both query forms below.
+    val expected = Row(
+      "b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+      "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")
+
+    // DataFrame DSL, using both the Column and the column-name overloads.
+    checkAnswer(df.select(sha2($"a", 256), sha2("b", 256)), expected)
+
+    // The same query through the SQL expression parser.
+    checkAnswer(df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), expected)
+
+    // A bit length outside the permitted set is rejected when the column is built.
+    intercept[IllegalArgumentException] {
+      df.select(sha2($"a", 1024))
+    }
+  }
+
test("string length function") {
checkAnswer(
nullStrings.select(strlen($"s"), strlen("s")),