From 85f2303ecadd9bf6d9694a2743dda075654c5ccf Mon Sep 17 00:00:00 2001 From: petermaxlee Date: Fri, 1 Jul 2016 07:57:48 +0800 Subject: [SPARK-16276][SQL] Implement elt SQL function ## What changes were proposed in this pull request? This patch implements the elt function, as it is implemented in Hive. ## How was this patch tested? Added expression unit test in StringExpressionsSuite and end-to-end test in StringFunctionsSuite. Author: petermaxlee Closes #13966 from petermaxlee/SPARK-16276. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/ExpectsInputTypes.scala | 3 +- .../catalyst/expressions/stringExpressions.scala | 41 ++++++++++++++++++++++ .../expressions/StringExpressionsSuite.scala | 23 ++++++++++++ .../apache/spark/sql/StringFunctionsSuite.scala | 14 ++++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 +- 6 files changed, 82 insertions(+), 2 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3fbdb2ab57..26b0c30db4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -267,6 +267,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), + expression[Elt]("elt"), expression[Encode]("encode"), expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index c15a2df508..98f25a9ad7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -57,7 +57,8 @@ trait ExpectsInputTypes extends Expression { /** - * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + * A mixin for the analyzer to perform implicit type casting using + * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. */ trait ImplicitCastInputTypes extends ExpectsInputTypes { // No other methods diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 44ff7fda8e..b0df957637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -21,6 +21,7 @@ import java.text.{DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ @@ -162,6 +163,46 @@ case class ConcatWs(children: Seq[Expression]) } } +@ExpressionDescription( + usage = "_FUNC_(n, str1, str2, ...) - returns the n-th string, e.g. returns str2 when n is 2", + extended = "> SELECT _FUNC_(1, 'scala', 'java') FROM src LIMIT 1;\n" + "'scala'") +case class Elt(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + private lazy val indexExpr = children.head + private lazy val stringExprs = children.tail.toArray + + /** This expression is always nullable because it returns null if index is out of range. */ + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size < 2) { + TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") + } else { + super[ImplicitCastInputTypes].checkInputDataTypes() + } + } + + override def eval(input: InternalRow): Any = { + val indexObj = indexExpr.eval(input) + if (indexObj == null) { + null + } else { + val index = indexObj.asInstanceOf[Int] + if (index <= 0 || index > stringExprs.length) { + null + } else { + stringExprs(index - 1).eval(input) + } + } + } +} + + trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 29bf15bf52..5f01561986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -75,6 +75,29 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("elt") { + def testElt(result: String, n: java.lang.Integer, args: String*): Unit = { + checkEvaluation( + Elt(Literal.create(n, IntegerType) +: args.map(Literal.create(_, StringType))), + result) + } + + testElt("hello", 1, "hello", "world") + testElt(null, 1, null, "world") + testElt(null, null, "hello", "world") + + // Invalid ranages + testElt(null, 3, "hello", "world") + testElt(null, 0, "hello", "world") + testElt(null, -1, "hello", "world") + + // type checking + assert(Elt(Seq.empty).checkInputDataTypes().isFailure) + assert(Elt(Seq(Literal(1))).checkInputDataTypes().isFailure) + assert(Elt(Seq(Literal(1), Literal("A"))).checkInputDataTypes().isSuccess) + assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure) + } + test("StringComparison") { val row = create_row("abc", null) val c1 = 'a.string.at(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 1de2d9b5ad..dff4226051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -48,6 +48,20 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("a||b")) } + test("string elt") { + val df = Seq[(String, String, String, Int)](("hello", "world", null, 15)) + .toDF("a", "b", "c", "d") + + checkAnswer( + df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), + Row(null, "hello", null)) + + // check implicit type cast + checkAnswer( + df.selectExpr("elt(4, a, b, c, d)", "elt('2', a, b, c, d)"), + Row("15", "world")) + } + test("string Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 195591fd9d..1fffadbfca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -238,7 +238,7 @@ private[sql] class HiveSessionCatalog( // parse_url_tuple, posexplode, reflect2, // str_to_map, windowingtablefunction. private val hiveFunctions = Seq( - "elt", "hash", "java_method", "histogram_numeric", + "hash", "java_method", "histogram_numeric", "map_keys", "map_values", "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map", "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long", -- cgit v1.2.3