aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/functions.py19
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala98
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala17
6 files changed, 165 insertions, 4 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index cfa87aeea1..7d3d036161 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@ __all__ = [
'monotonicallyIncreasingId',
'rand',
'randn',
+ 'sha2',
'sparkPartitionId',
'struct',
'udf',
@@ -363,6 +364,24 @@ def randn(seed=None):
return Column(jc)
+@ignore_unicode_prefix
+@since(1.5)
+def sha2(col, numBits):
+ """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384,
+ and SHA-512). The numBits indicates the desired bit length of the result, which must have a
+ value of 224, 256, 384, 512, or 0 (which is equivalent to 256).
+
+ >>> digests = df.select(sha2(df.name, 256).alias('s')).collect()
+ >>> digests[0]
+ Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043')
+ >>> digests[1]
+ Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961')
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.sha2(_to_java_column(col), numBits)
+ return Column(jc)
+
+
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 5fb3369f85..457948a800 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -135,6 +135,7 @@ object FunctionRegistry {
// misc functions
expression[Md5]("md5"),
+ expression[Sha2]("sha2"),
// aggregate functions
expression[Average]("avg"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 4bee8cb728..e80706fc65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import java.security.MessageDigest
+import java.security.NoSuchAlgorithmException
+
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, StringType, DataType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -44,7 +47,96 @@ case class Md5(child: Expression)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c =>
- "org.apache.spark.unsafe.types.UTF8String.fromString" +
- s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+ s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))")
+ }
+}
+
+/**
+ * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512)
+ * and returns it as a hex string. The first argument is the string or binary to be hashed. The
+ * second argument indicates the desired bit length of the result, which must have a value of 224,
+ * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If
+ * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or
+ * the hash length is not one of the permitted values, the return value is NULL.
+ */
+case class Sha2(left: Expression, right: Expression)
+ extends BinaryExpression with Serializable with ExpectsInputTypes {
+
+ override def dataType: DataType = StringType
+
+ override def toString: String = s"SHA2($left, $right)"
+
+ override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+
+ override def eval(input: InternalRow): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ val bitLength = evalE2.asInstanceOf[Int]
+ val input = evalE1.asInstanceOf[Array[Byte]]
+ bitLength match {
+ case 224 =>
+ // DigestUtils doesn't support SHA-224 now
+ try {
+ val md = MessageDigest.getInstance("SHA-224")
+ md.update(input)
+ UTF8String.fromBytes(md.digest())
+ } catch {
+ // SHA-224 is not supported on the system, return null
+ case noa: NoSuchAlgorithmException => null
+ }
+ case 256 | 0 =>
+ UTF8String.fromString(DigestUtils.sha256Hex(input))
+ case 384 =>
+ UTF8String.fromString(DigestUtils.sha384Hex(input))
+ case 512 =>
+ UTF8String.fromString(DigestUtils.sha512Hex(input))
+ case _ => null
+ }
+ }
+ }
+ }
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val eval1 = left.gen(ctx)
+ val eval2 = right.gen(ctx)
+ val digestUtils = "org.apache.commons.codec.digest.DigestUtils"
+
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = ${eval1.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${eval2.code}
+ if (!${eval2.isNull}) {
+ if (${eval2.primitive} == 224) {
+ try {
+ java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224");
+ md.update(${eval1.primitive});
+ ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest());
+ } catch (java.security.NoSuchAlgorithmException e) {
+ ${ev.isNull} = true;
+ }
+ } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive}));
+ } else if (${eval2.primitive} == 384) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive}));
+ } else if (${eval2.primitive} == 512) {
+ ${ev.primitive} =
+ ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive}));
+ } else {
+ ${ev.isNull} = true;
+ }
+ } else {
+ ${ev.isNull} = true;
+ }
+ }
+ """
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 48b84130b4..38482c54c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.commons.codec.digest.DigestUtils
+
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{StringType, BinaryType}
+import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
}
+ test("sha2") {
+ checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
+ checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
+ DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
+ // unsupported bit length
+ checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
+ checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
+ checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
+ checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 38d9085a50..355ce0e342 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1414,6 +1414,26 @@ object functions {
*/
def md5(columnName: String): Column = md5(Column(columnName))
+ /**
+ * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
+ *
+ * @group misc_funcs
+ * @since 1.5.0
+ */
+ def sha2(e: Column, numBits: Int): Column = {
+ require(Seq(0, 224, 256, 384, 512).contains(numBits),
+ s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)")
+ Sha2(e.expr, lit(numBits).expr)
+ }
+
+ /**
+ * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
+ *
+ * @group misc_funcs
+ * @since 1.5.0
+ */
+ def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits)
+
//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8b53b384a2..8baed57a7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest {
Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
}
+ test("misc sha2 function") {
+ val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
+ checkAnswer(
+ df.select(sha2($"a", 256), sha2("b", 256)),
+ Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+ "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
+
+ checkAnswer(
+ df.selectExpr("sha2(a, 256)", "sha2(b, 256)"),
+ Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
+ "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
+
+ intercept[IllegalArgumentException] {
+ df.select(sha2($"a", 1024))
+ }
+ }
+
test("string length function") {
checkAnswer(
nullStrings.select(strlen($"s"), strlen("s")),