aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSandeep Singh <sandeep@techaddict.me>2016-07-22 10:05:21 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-22 10:05:21 +0800
commitdf2c6d59d0e1a3db9942dbc5e4993cf3babc2d60 (patch)
tree6802b7c39aa8dfb60ee31066939b0eef998183be
parent46f80a307392bee6743e5847eb5243bf5fcd00a4 (diff)
downloadspark-df2c6d59d0e1a3db9942dbc5e4993cf3babc2d60.tar.gz
spark-df2c6d59d0e1a3db9942dbc5e4993cf3babc2d60.tar.bz2
spark-df2c6d59d0e1a3db9942dbc5e4993cf3babc2d60.zip
[SPARK-16287][SQL] Implement str_to_map SQL function
## What changes were proposed in this pull request? This PR adds the `str_to_map` SQL function in order to remove the Hive fallback. ## How was this patch tested? Pass the Jenkins tests with newly added test cases. Author: Sandeep Singh <sandeep@techaddict.me> Closes #13990 from techaddict/SPARK-16287.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala52
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala23
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala3
5 files changed, 112 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 65a90d8099..65168998c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -228,6 +228,7 @@ object FunctionRegistry {
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Sinh]("sinh"),
+ expression[StringToMap]("str_to_map"),
expression[Sqrt]("sqrt"),
expression[Tan]("tan"),
expression[Tanh]("tanh"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d603d3c73e..b3c5c585c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -393,3 +393,53 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
override def prettyName: String = "named_struct_unsafe"
}
+
+/**
+ * Creates a map after splitting the input text into key/value pairs using delimiters.
+ * Default delimiters are "," between pairs and ":" between key and value (supplied by
+ * the auxiliary constructors below). Registered as the SQL function `str_to_map`.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(text[, pairDelim, keyValueDelim]) - Creates a map after splitting the text " +
+ "into key/value pairs using delimiters. " +
+ "Default delimiters are ',' for pairDelim and ':' for keyValueDelim.",
+ extended = """ > SELECT _FUNC_('a:1,b:2,c:3',',',':');\n map("a":"1","b":"2","c":"3") """)
+case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
+ extends TernaryExpression with CodegenFallback with ExpectsInputTypes {
+
+ // Two-argument form: caller supplies the pair delimiter; key/value delimiter defaults to ":".
+ def this(child: Expression, pairDelim: Expression) = {
+ this(child, pairDelim, Literal(":"))
+ }
+
+ // One-argument form: both delimiters take their defaults ("," and ":").
+ def this(child: Expression) = {
+ this(child, Literal(","), Literal(":"))
+ }
+
+ override def children: Seq[Expression] = Seq(text, pairDelim, keyValueDelim)
+
+ // All three inputs (text and both delimiters) must be strings.
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
+
+ // NOTE(review): declares valueContainsNull = false, yet nullSafeEval below can produce a
+ // null value for a pair that lacks the key/value delimiter — confirm this is intended.
+ override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false)
+
+ // Both delimiters must be foldable (constant-evaluable) expressions; otherwise the
+ // split behaviour could vary per row, which this implementation does not support.
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (Seq(pairDelim, keyValueDelim).exists(! _.foldable)) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName's delimiters must be foldable.")
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
+ // Splits the text on delim1 into pairs (keeping trailing empty strings via limit -1),
+ // then splits each pair on delim2 at most once into (key, value). A pair containing no
+ // delim2 becomes (pair, null). Inputs are non-null here (nullSafeEval contract).
+ override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = {
+ val array = str.asInstanceOf[UTF8String]
+ .split(delim1.asInstanceOf[UTF8String], -1)
+ .map { kv =>
+ val arr = kv.split(delim2.asInstanceOf[UTF8String], 2)
+ if (arr.length < 2) {
+ Array(arr(0), null)
+ } else {
+ arr
+ }
+ }
+ ArrayBasedMapData(array.map(_ (0)), array.map(_ (1)))
+ }
+
+ override def prettyName: String = "str_to_map"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index ec7be4d4b8..0c307b2b85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -246,4 +246,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkMetadata(CreateStructUnsafe(Seq(a, b)))
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}
+
+ // Evaluation and type-checking coverage for the StringToMap expression.
+ test("StringToMap") {
+ // Default delimiters: "," between pairs, ":" between key and value.
+ val s0 = Literal("a:1,b:2,c:3")
+ val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3")
+ checkEvaluation(new StringToMap(s0), m0)
+
+ // Whitespace values are preserved verbatim, not trimmed.
+ val s1 = Literal("a: ,b:2")
+ val m1 = Map("a" -> " ", "b" -> "2")
+ checkEvaluation(new StringToMap(s1), m1)
+
+ // Explicit custom delimiters ("," and "=").
+ val s2 = Literal("a=1,b=2,c=3")
+ val m2 = Map("a" -> "1", "b" -> "2", "c" -> "3")
+ checkEvaluation(StringToMap(s2, Literal(","), Literal("=")), m2)
+
+ // Empty input yields a single entry: empty key mapped to null.
+ val s3 = Literal("")
+ val m3 = Map[String, String]("" -> null)
+ checkEvaluation(StringToMap(s3, Literal(","), Literal("=")), m3)
+
+ // Custom pair delimiter with the default key/value delimiter.
+ val s4 = Literal("a:1_b:2_c:3")
+ val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3")
+ checkEvaluation(new StringToMap(s4, Literal("_")), m4)
+
+ // arguments checking
+ assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
+ // Null literals fail the StringType input requirement — presumably typed as NullType.
+ assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure)
+ assert(new StringToMap(Literal("a:1,b:2,c:3"), Literal(null)).checkInputDataTypes().isFailure)
+ assert(StringToMap(Literal("a:1,b:2,c:3"), Literal(null), Literal(null))
+ .checkInputDataTypes().isFailure)
+ assert(new StringToMap(Literal(null), Literal(null)).checkInputDataTypes().isFailure)
+
+ // Non-foldable delimiters must be rejected at analysis time.
+ assert(new StringToMap(Literal("a:1_b:2_c:3"), NonFoldableLiteral("_"))
+ .checkInputDataTypes().isFailure)
+ assert(
+ new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("="))
+ .checkInputDataTypes().isFailure)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index f509551b1e..524926e1e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -384,4 +384,27 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}.getMessage
assert(m.contains("Invalid number of arguments for function sentences"))
}
+
+ // SQL-level coverage of str_to_map invoked through selectExpr.
+ test("str_to_map function") {
+ val df1 = Seq(
+ ("a=1,b=2", "y"),
+ ("a=1,b=2,c=3", "y")
+ ).toDF("a", "b")
+
+ // Three-argument form with explicit pair and key/value delimiters.
+ checkAnswer(
+ df1.selectExpr("str_to_map(a,',','=')"),
+ Seq(
+ Row(Map("a" -> "1", "b" -> "2")),
+ Row(Map("a" -> "1", "b" -> "2", "c" -> "3"))
+ )
+ )
+
+ val df2 = Seq(("a:1,b:2,c:3", "y")).toDF("a", "b")
+
+ // One-argument form using the default delimiters ("," and ":").
+ checkAnswer(
+ df2.selectExpr("str_to_map(a)"),
+ Seq(Row(Map("a" -> "1", "b" -> "2", "c" -> "3")))
+ )
+
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index b8a75850b1..c59ac3dcaf 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -238,7 +238,6 @@ private[sql] class HiveSessionCatalog(
"hash",
"histogram_numeric",
"percentile",
- "percentile_approx",
- "str_to_map"
+ "percentile_approx"
)
}