aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-07-08 17:05:24 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-08 17:05:24 +0800
commita54438cb23c80f7c7fc35da273677c39317cb1a5 (patch)
treed7f02a31c45eebf00a8a76d7f894dd8468239e1e
parent8228b06303718b202be60b830df7dfddd97057b1 (diff)
downloadspark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.gz
spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.tar.bz2
spark-a54438cb23c80f7c7fc35da273677c39317cb1a5.zip
[SPARK-16285][SQL] Implement sentences SQL functions
## What changes were proposed in this pull request? This PR implements `sentences` SQL function. ## How was this patch tested? Pass the Jenkins tests with a new testcase. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #14004 from dongjoon-hyun/SPARK_16285.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala68
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala23
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala20
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala2
5 files changed, 111 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f6ebcaeded..842c9c63ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -296,6 +296,7 @@ object FunctionRegistry {
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
+ expression[Sentences]("sentences"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index b0df957637..894e12d4a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
-import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression)
override def prettyName: String = "format_number"
}
+
+/**
+ * Splits a string into arrays of sentences, where each sentence is an array of words.
+ * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
+ extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
+case class Sentences(
+ str: Expression,
+ language: Expression = Literal(""),
+ country: Expression = Literal(""))
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+ def this(str: Expression) = this(str, Literal(""), Literal(""))
+ def this(str: Expression, language: Expression) = this(str, language, Literal(""))
+
+ override def nullable: Boolean = true
+ override def dataType: DataType =
+ ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
+ override def children: Seq[Expression] = str :: language :: country :: Nil
+
+ override def eval(input: InternalRow): Any = {
+ val string = str.eval(input)
+ if (string == null) {
+ null
+ } else {
+ val languageStr = language.eval(input).asInstanceOf[UTF8String]
+ val countryStr = country.eval(input).asInstanceOf[UTF8String]
+ val locale = if (languageStr != null && countryStr != null) {
+ new Locale(languageStr.toString, countryStr.toString)
+ } else {
+ Locale.getDefault
+ }
+ getSentences(string.asInstanceOf[UTF8String].toString, locale)
+ }
+ }
+
+ private def getSentences(sentences: String, locale: Locale) = {
+ val bi = BreakIterator.getSentenceInstance(locale)
+ bi.setText(sentences)
+ var idx = 0
+ val result = new ArrayBuffer[GenericArrayData]
+ while (bi.next != BreakIterator.DONE) {
+ val sentence = sentences.substring(idx, bi.current)
+ idx = bi.current
+
+ val wi = BreakIterator.getWordInstance(locale)
+ var widx = 0
+ wi.setText(sentence)
+ val words = new ArrayBuffer[UTF8String]
+ while (wi.next != BreakIterator.DONE) {
+ val word = sentence.substring(widx, wi.current)
+ widx = wi.current
+ if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
+ }
+ result += new GenericArrayData(words)
+ }
+ new GenericArrayData(result)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 5f01561986..256ce85743 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -725,4 +725,27 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}
+
+ test("Sentences") {
+ val nullString = Literal.create(null, StringType)
+ checkEvaluation(Sentences(nullString, nullString, nullString), null)
+ checkEvaluation(Sentences(nullString, nullString), null)
+ checkEvaluation(Sentences(nullString), null)
+ checkEvaluation(Sentences(Literal.create(null, NullType)), null)
+ checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
+ checkEvaluation(Sentences("", nullString), Seq.empty)
+ checkEvaluation(Sentences(""), Seq.empty)
+
+ val answer = Seq(
+ Seq("Hi", "there"),
+ Seq("The", "price", "was"),
+ Seq("But", "not", "now"))
+
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now."), answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"),
+ answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"),
+ answer)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3edd988496..044ac22328 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -349,4 +349,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
df2.filter("b>0").selectExpr("format_number(a, b)"),
Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil)
}
+
+ test("string sentences function") {
+ val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US"))
+ .toDF("str", "language", "country")
+
+ checkAnswer(
+ df.selectExpr("sentences(str, language, country)"),
+ Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))
+
+ // Type coercion
+ checkAnswer(
+ df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
+ Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))
+
+ // Argument number exception
+ val m = intercept[AnalysisException] {
+ df.selectExpr("sentences()")
+ }.getMessage
+ assert(m.contains("Invalid number of arguments for function sentences"))
+ }
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index fdc4c18e70..6f05f0f305 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
- "parse_url", "percentile", "percentile_approx", "reflect", "sentences", "str_to_map",
+ "parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string"
)