From a54438cb23c80f7c7fc35da273677c39317cb1a5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 8 Jul 2016 17:05:24 +0800 Subject: [SPARK-16285][SQL] Implement sentences SQL functions ## What changes were proposed in this pull request? This PR implements `sentences` SQL function. ## How was this patch tested? Pass the Jenkins tests with a new testcase. Author: Dongjoon Hyun Closes #14004 from dongjoon-hyun/SPARK_16285. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/stringExpressions.scala | 68 +++++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 23 ++++++++ .../apache/spark/sql/StringFunctionsSuite.scala | 20 +++++++ .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 +- 5 files changed, 111 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f6ebcaeded..842c9c63ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -296,6 +296,7 @@ object FunctionRegistry { expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), + expression[Sentences]("sentences"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), expression[StringSplit]("split"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0df957637..894e12d4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,13 +17,15 @@ package 
/**
 * Splits a string into arrays of sentences, where each sentence is an array of words.
 * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
 */
@ExpressionDescription(
  usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
  extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
case class Sentences(
    str: Expression,
    language: Expression = Literal(""),
    country: Expression = Literal(""))
  extends Expression with ImplicitCastInputTypes with CodegenFallback {

  def this(str: Expression) = this(str, Literal(""), Literal(""))
  def this(str: Expression, language: Expression) = this(str, language, Literal(""))

  // Result is null only when the input string is null; the locale arguments being null
  // does not null out the result (we fall back to the default locale instead).
  override def nullable: Boolean = true
  override def dataType: DataType =
    ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
  override def children: Seq[Expression] = str :: language :: country :: Nil

  override def eval(input: InternalRow): Any = {
    val string = str.eval(input)
    if (string == null) {
      null
    } else {
      // Use the caller-supplied locale only when BOTH language and country evaluate to
      // non-null; otherwise use the JVM default locale. (An unrecognized language/country
      // pair still constructs a valid Locale — BreakIterator then uses its fallback rules.)
      val languageStr = language.eval(input).asInstanceOf[UTF8String]
      val countryStr = country.eval(input).asInstanceOf[UTF8String]
      val locale = if (languageStr != null && countryStr != null) {
        new Locale(languageStr.toString, countryStr.toString)
      } else {
        Locale.getDefault
      }
      getSentences(string.asInstanceOf[UTF8String].toString, locale)
    }
  }

  /**
   * Breaks `sentences` at sentence boundaries using a locale-aware [[BreakIterator]], then
   * breaks each sentence into words, keeping only tokens whose first character is a letter
   * or digit (this drops the punctuation/whitespace tokens the word iterator also emits).
   */
  private def getSentences(sentences: String, locale: Locale) = {
    val si = BreakIterator.getSentenceInstance(locale)
    si.setText(sentences)
    // The word iterator depends only on the locale, not on the sentence text, so create it
    // once and re-seed it with setText per sentence instead of re-allocating in the loop.
    val wi = BreakIterator.getWordInstance(locale)
    var idx = 0
    val result = new ArrayBuffer[GenericArrayData]
    while (si.next != BreakIterator.DONE) {
      val sentence = sentences.substring(idx, si.current)
      idx = si.current

      wi.setText(sentence)
      var widx = 0
      val words = new ArrayBuffer[UTF8String]
      while (wi.next != BreakIterator.DONE) {
        val word = sentence.substring(widx, wi.current)
        widx = wi.current
        // Boundary positions are strictly increasing, so `word` is never empty here.
        if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
      }
      result += new GenericArrayData(words)
    }
    new GenericArrayData(result)
  }
}
// Added to StringExpressionsSuite: expression-level coverage for Sentences.
test("Sentences") {
  val nullString = Literal.create(null, StringType)

  // A null input string propagates to a null result, at every arity.
  checkEvaluation(Sentences(nullString, nullString, nullString), null)
  checkEvaluation(Sentences(nullString, nullString), null)
  checkEvaluation(Sentences(nullString), null)
  checkEvaluation(Sentences(Literal.create(null, NullType)), null)

  // An empty input string yields an empty sentence array, at every arity.
  checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
  checkEvaluation(Sentences("", nullString), Seq.empty)
  checkEvaluation(Sentences(""), Seq.empty)

  val text = "Hi there! The price was $1,234.56.... But, not now."
  val expected = Seq(
    Seq("Hi", "there"),
    Seq("The", "price", "was"),
    Seq("But", "not", "now"))

  // No locale, a valid locale, and an unrecognized locale all produce the same split.
  checkEvaluation(Sentences(text), expected)
  checkEvaluation(Sentences(text, "en"), expected)
  checkEvaluation(Sentences(text, "en", "US"), expected)
  checkEvaluation(Sentences(text, "XXX", "YYY"), expected)
}

// Added to StringFunctionsSuite: DataFrame/SQL-level coverage for sentences().
test("string sentences function") {
  val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US"))
    .toDF("str", "language", "country")

  checkAnswer(
    df.selectExpr("sentences(str, language, country)"),
    Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

  // Type coercion: non-string arguments are implicitly cast to string first.
  checkAnswer(
    df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
    Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))

  // Calling with no arguments is an analysis-time error.
  val m = intercept[AnalysisException] {
    df.selectExpr("sentences()")
  }.getMessage
  assert(m.contains("Invalid number of arguments for function sentences"))
}
// Functions that are still resolved through Hive's own function registry rather than
// Spark's native FunctionRegistry. "sentences" is not listed here because Spark now
// ships a native implementation of it.
private val hiveFunctions = Seq(
  "hash",
  "java_method",
  "histogram_numeric",
  "parse_url",
  "percentile",
  "percentile_approx",
  "reflect",
  "str_to_map",
  "xpath",
  "xpath_double",
  "xpath_float",
  "xpath_int",
  "xpath_long",
  "xpath_number",
  "xpath_short",
  "xpath_string")