aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-07-08 17:05:24 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-08 17:05:24 +0800
commita54438cb23c80f7c7fc35da273677c39317cb1a5 (patch)
treed7f02a31c45eebf00a8a76d7f894dd8468239e1e /sql/catalyst/src/main
parent8228b06303718b202be60b830df7dfddd97057b1 (diff)
downloadspark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.gz
spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.bz2
spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.zip
[SPARK-16285][SQL] Implement sentences SQL function
## What changes were proposed in this pull request? This PR implements the `sentences` SQL function. ## How was this patch tested? Passes the Jenkins tests with a new test case. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #14004 from dongjoon-hyun/SPARK_16285.
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala68
2 files changed, 67 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f6ebcaeded..842c9c63ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -296,6 +296,7 @@ object FunctionRegistry {
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
+ expression[Sentences]("sentences"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b0df957637..894e12d4a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
-import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression)
override def prettyName: String = "format_number"
}
+
+/**
+ * Splits a string into an array of sentences, where each sentence is an array of words.
+ *
+ * Sentence and word boundaries are located with `java.text.BreakIterator` for the locale
+ * built from the optional 'lang' and 'country' arguments. Both arguments default to the
+ * empty string, so when they are omitted the locale is `new Locale("", "")`;
+ * `Locale.getDefault` is used only when either argument evaluates to null.
+ * The result is null when 'str' evaluates to null.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
+  extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
+case class Sentences(
+    str: Expression,
+    language: Expression = Literal(""),
+    country: Expression = Literal(""))
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  // NOTE(review): these auxiliary constructors repeat the default arguments above; presumably
+  // they exist so the function registry can look up 1- and 2-argument forms - confirm before
+  // removing them.
+  def this(str: Expression) = this(str, Literal(""), Literal(""))
+  def this(str: Expression, language: Expression) = this(str, language, Literal(""))
+
+  // A null input string yields a null result, so the expression is always nullable.
+  override def nullable: Boolean = true
+  override def dataType: DataType =
+    ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
+  override def children: Seq[Expression] = str :: language :: country :: Nil
+
+  /**
+   * Evaluates 'str' and, unless it is null, splits it into sentences of words using the
+   * locale described by 'language' and 'country'.
+   */
+  override def eval(input: InternalRow): Any = {
+    val string = str.eval(input)
+    if (string == null) {
+      null
+    } else {
+      val languageStr = language.eval(input).asInstanceOf[UTF8String]
+      val countryStr = country.eval(input).asInstanceOf[UTF8String]
+      // Fall back to the JVM default locale only when one of the two parts is null.
+      val locale = if (languageStr != null && countryStr != null) {
+        new Locale(languageStr.toString, countryStr.toString)
+      } else {
+        Locale.getDefault
+      }
+      getSentences(string.asInstanceOf[UTF8String].toString, locale)
+    }
+  }
+
+  /**
+   * Breaks `sentences` into sentences and each sentence into words for the given locale.
+   * Only tokens whose first code point is a letter or digit are kept as words, which drops
+   * whitespace and punctuation tokens.
+   */
+  private def getSentences(sentences: String, locale: Locale): GenericArrayData = {
+    val sentenceIter = BreakIterator.getSentenceInstance(locale)
+    // A single word iterator is enough: setText() re-targets it at each sentence.
+    val wordIter = BreakIterator.getWordInstance(locale)
+    sentenceIter.setText(sentences)
+    val result = new ArrayBuffer[GenericArrayData]
+    var idx = 0
+    while (sentenceIter.next() != BreakIterator.DONE) {
+      val sentence = sentences.substring(idx, sentenceIter.current)
+      idx = sentenceIter.current
+
+      wordIter.setText(sentence)
+      val words = new ArrayBuffer[UTF8String]
+      var widx = 0
+      while (wordIter.next() != BreakIterator.DONE) {
+        val word = sentence.substring(widx, wordIter.current)
+        widx = wordIter.current
+        // Test the first code point rather than the first UTF-16 unit, so that words
+        // starting with a supplementary-plane letter or digit are not dropped.
+        if (Character.isLetterOrDigit(word.codePointAt(0))) words += UTF8String.fromString(word)
+      }
+      result += new GenericArrayData(words)
+    }
+    new GenericArrayData(result)
+  }
+}