about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  python/pyspark/sql/functions.py                                                              14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala     2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala             30
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala                   12
6 files changed, 81 insertions, 1 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7d3d036161..45ecd826bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -42,6 +42,7 @@ __all__ = [
'monotonicallyIncreasingId',
'rand',
'randn',
+ 'sha1',
'sha2',
'sparkPartitionId',
'struct',
@@ -382,6 +383,19 @@ def sha2(col, numBits):
return Column(jc)
+@ignore_unicode_prefix
+@since(1.5)
+def sha1(col):
+ """Returns the hex string result of SHA-1.
+
+ >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect()
+ [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.sha1(_to_java_column(col))
+ return Column(jc)
+
+
@since(1.4)
def sparkPartitionId():
"""A column for partition ID of the Spark task.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 457948a800..b24064d061 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -136,6 +136,8 @@ object FunctionRegistry {
// misc functions
expression[Md5]("md5"),
expression[Sha2]("sha2"),
+ expression[Sha1]("sha1"),
+ expression[Sha1]("sha"),
// aggregate functions
expression[Average]("avg"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e80706fc65..9a39165a1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,8 +21,9 @@ import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import org.apache.commons.codec.digest.DigestUtils
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType}
+import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression)
"""
}
}
+
+/**
+ * A function that calculates a sha1 hash value and returns it as a hex string
+ * For input of type [[BinaryType]] or [[StringType]]
+ */
+case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def dataType: DataType = StringType
+
+ override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ null
+ } else {
+ UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]]))
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, c =>
+ "org.apache.spark.unsafe.types.UTF8String.fromString" +
+ s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))"
+ )
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 38482c54c6..36e636b5da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
}
+ test("sha1") {
+ checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
+ checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
+ "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
+ checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
+ checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
+ }
+
test("sha2") {
checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 355ce0e342..ef92801548 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1415,6 +1415,22 @@ object functions {
def md5(columnName: String): Column = md5(Column(columnName))
/**
+ * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+ *
+ * @group misc_funcs
+ * @since 1.5.0
+ */
+ def sha1(e: Column): Column = Sha1(e.expr)
+
+ /**
+ * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+ *
+ * @group misc_funcs
+ * @since 1.5.0
+ */
+ def sha1(columnName: String): Column = sha1(Column(columnName))
+
+ /**
* Calculates the SHA-2 family of hash functions and returns the value as a hex string.
*
* @group misc_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8baed57a7f..abfd47c811 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest {
Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
}
+ test("misc sha1 function") {
+ val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b")
+ checkAnswer(
+ df.select(sha1($"a"), sha1("b")),
+ Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8"))
+
+ val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b")
+ checkAnswer(
+ dfEmpty.selectExpr("sha1(a)", "sha1(b)"),
+ Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709"))
+ }
+
test("misc sha2 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(