author: petermaxlee <petermaxlee@gmail.com> 2016-07-01 07:57:48 +0800
committer: Wenchen Fan <wenchen@databricks.com> 2016-07-01 07:57:48 +0800
commit: 85f2303ecadd9bf6d9694a2743dda075654c5ccf (patch)
tree: 09125fdca897a2995d33cf4378f3bfc2ac018ad1 /sql
parent: 3d75a5b2a76eba0855d73476dc2fd579c612d521 (diff)
[SPARK-16276][SQL] Implement elt SQL function
## What changes were proposed in this pull request?

This patch implements the elt function, as it is implemented in Hive.

## How was this patch tested?

Added expression unit test in StringExpressionsSuite and end-to-end test in StringFunctionsSuite.

Author: petermaxlee <petermaxlee@gmail.com>

Closes #13966 from petermaxlee/SPARK-16276.
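As a quick orientation before the diff: elt takes a 1-based index followed by any number of strings, returns the string at that index, and yields NULL for a null or out-of-range index. A minimal spark-shell sketch of that behavior (hedged: this assumes a Spark build that includes this patch; the queries are illustrative, not taken from it):

// Illustrative spark-shell session; elt uses 1-based indexing and
// returns NULL rather than failing on an out-of-range index.
spark.sql("SELECT elt(2, 'scala', 'java')").show()   // java
spark.sql("SELECT elt(0, 'scala', 'java')").show()   // NULL: index 0 is out of range
spark.sql("SELECT elt(3, 'scala', 'java')").show()   // NULL: only two strings were given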
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 1
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala | 3
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala | 41
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala | 23
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala | 14
-rw-r--r-- sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala | 2
6 files changed, 82 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 3fbdb2ab57..26b0c30db4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -267,6 +267,7 @@ object FunctionRegistry {
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
+ expression[Elt]("elt"),
expression[Encode]("encode"),
expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index c15a2df508..98f25a9ad7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -57,7 +57,8 @@ trait ExpectsInputTypes extends Expression {
/**
- * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]].
+ * A mixin for the analyzer to perform implicit type casting using
+ * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].
*/
trait ImplicitCastInputTypes extends ExpectsInputTypes {
// No other methods
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 44ff7fda8e..b0df957637 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -21,6 +21,7 @@ import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
@@ -162,6 +163,46 @@ case class ConcatWs(children: Seq[Expression])
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(n, str1, str2, ...) - returns the n-th string, e.g. returns str2 when n is 2",
+ extended = "> SELECT _FUNC_(1, 'scala', 'java') FROM src LIMIT 1;\n" + "'scala'")
+case class Elt(children: Seq[Expression])
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+ private lazy val indexExpr = children.head
+ private lazy val stringExprs = children.tail.toArray
+
+ /** This expression is always nullable because it returns null if index is out of range. */
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size < 2) {
+ TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
+ } else {
+ super[ImplicitCastInputTypes].checkInputDataTypes()
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val indexObj = indexExpr.eval(input)
+ if (indexObj == null) {
+ null
+ } else {
+ val index = indexObj.asInstanceOf[Int]
+ if (index <= 0 || index > stringExprs.length) {
+ null
+ } else {
+ stringExprs(index - 1).eval(input)
+ }
+ }
+ }
+}
+
+
trait String2StringExpression extends ImplicitCastInputTypes {
self: UnaryExpression =>
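The heart of the new Elt expression above is its eval method: a null-propagating, 1-based lookup into the trailing string arguments. Below is a standalone Scala sketch of the same logic, stripped of Spark's Expression machinery (the eltLookup name and the plain Array[String] input are illustrative, not part of the patch):

// Hypothetical model of Elt.eval: a null index or an out-of-range
// index yields null; otherwise the 1-based index selects a string.
def eltLookup(index: java.lang.Integer, strings: Array[String]): String = {
  if (index == null) {
    null                                   // a null index propagates as null
  } else if (index <= 0 || index > strings.length) {
    null                                   // out of range is null, not an error (Hive semantics)
  } else {
    strings(index - 1)                     // SQL's 1-based index -> 0-based array access
  }
}

eltLookup(1, Array("hello", "world"))      // "hello"
eltLookup(3, Array("hello", "world"))      // null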
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 29bf15bf52..5f01561986 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -75,6 +75,29 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
+ test("elt") {
+ def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
+ checkEvaluation(
+ Elt(Literal.create(n, IntegerType) +: args.map(Literal.create(_, StringType))),
+ result)
+ }
+
+ testElt("hello", 1, "hello", "world")
+ testElt(null, 1, null, "world")
+ testElt(null, null, "hello", "world")
+
+ // Invalid ranges
+ testElt(null, 3, "hello", "world")
+ testElt(null, 0, "hello", "world")
+ testElt(null, -1, "hello", "world")
+
+ // type checking
+ assert(Elt(Seq.empty).checkInputDataTypes().isFailure)
+ assert(Elt(Seq(Literal(1))).checkInputDataTypes().isFailure)
+ assert(Elt(Seq(Literal(1), Literal("A"))).checkInputDataTypes().isSuccess)
+ assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
+ }
+
test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 1de2d9b5ad..dff4226051 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -48,6 +48,20 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("a||b"))
}
+ test("string elt") {
+ val df = Seq[(String, String, String, Int)](("hello", "world", null, 15))
+ .toDF("a", "b", "c", "d")
+
+ checkAnswer(
+ df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"),
+ Row(null, "hello", null))
+
+ // check implicit type cast
+ checkAnswer(
+ df.selectExpr("elt(4, a, b, c, d)", "elt('2', a, b, c, d)"),
+ Row("15", "world"))
+ }
+
test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1)))
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 195591fd9d..1fffadbfca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -238,7 +238,7 @@ private[sql] class HiveSessionCatalog(
// parse_url_tuple, posexplode, reflect2,
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
- "elt", "hash", "java_method", "histogram_numeric",
+ "hash", "java_method", "histogram_numeric",
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",