diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-07-08 17:05:24 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-07-08 17:05:24 +0800 |
commit | a54438cb23c80f7c7fc35da273677c39317cb1a5 (patch) | |
tree | d7f02a31c45eebf00a8a76d7f894dd8468239e1e /sql/catalyst/src/main/scala | |
parent | 8228b06303718b202be60b830df7dfddd97057b1 (diff) | |
download | spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.gz spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.bz2 spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.zip |
[SPARK-16285][SQL] Implement sentences SQL functions
## What changes were proposed in this pull request?
This PR implements `sentences` SQL function.
## How was this patch tested?
Pass the Jenkins tests with a new testcase.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #14004 from dongjoon-hyun/SPARK_16285.
Diffstat (limited to 'sql/catalyst/src/main/scala')
2 files changed, 67 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f6ebcaeded..842c9c63ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -296,6 +296,7 @@ object FunctionRegistry { expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), + expression[Sentences]("sentences"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), expression[StringSplit]("split"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0df957637..894e12d4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression) override def prettyName: String = "format_number" } + +/** + * Splits a string into arrays of sentences, where each sentence is an array of words. + * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. + */ +@ExpressionDescription( + usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.", + extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]") +case class Sentences( + str: Expression, + language: Expression = Literal(""), + country: Expression = Literal("")) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + def this(str: Expression) = this(str, Literal(""), Literal("")) + def this(str: Expression, language: Expression) = this(str, language, Literal("")) + + override def nullable: Boolean = true + override def dataType: DataType = + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = str :: language :: country :: Nil + + override def eval(input: InternalRow): Any = { + val string = str.eval(input) + if (string == null) { + null + } else { + val languageStr = language.eval(input).asInstanceOf[UTF8String] + val countryStr = country.eval(input).asInstanceOf[UTF8String] + val locale = if (languageStr != null && countryStr != null) { + new Locale(languageStr.toString, countryStr.toString) + } else { + Locale.getDefault + } + getSentences(string.asInstanceOf[UTF8String].toString, locale) + } + } + + private def getSentences(sentences: String, locale: Locale) = { + val bi = BreakIterator.getSentenceInstance(locale) + bi.setText(sentences) + var idx = 0 + val result = new ArrayBuffer[GenericArrayData] + while (bi.next != BreakIterator.DONE) { + val sentence = sentences.substring(idx, bi.current) + idx = bi.current + + val wi = BreakIterator.getWordInstance(locale) + var widx = 0 + wi.setText(sentence) + val words = new ArrayBuffer[UTF8String] + while (wi.next != BreakIterator.DONE) { + val word = sentence.substring(widx, wi.current) + widx = wi.current + if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) + } + result += new GenericArrayData(words) + } + new GenericArrayData(result) + } +} |