aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHuJiayin <jiayin.hu@intel.com>2015-08-01 21:44:57 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-01 21:44:57 -0700
commit00cd92f32f17ca57d47aa2dcc716eb707aaee799 (patch)
tree87fae8a2daea19abc7dee69b551c5c0e6f54bf4b
parent5d9e33d9a2633e45082ac395a64646364f22f4c4 (diff)
downloadspark-00cd92f32f17ca57d47aa2dcc716eb707aaee799.tar.gz
spark-00cd92f32f17ca57d47aa2dcc716eb707aaee799.tar.bz2
spark-00cd92f32f17ca57d47aa2dcc716eb707aaee799.zip
[SPARK-8269] [SQL] string function: initcap
This PR is based on #7208 , thanks to HuJiayin Closes #7208 Author: HuJiayin <jiayin.hu@intel.com> Author: Davies Liu <davies@databricks.com> Closes #7850 from davies/initcap and squashes the following commits: 54472e9 [Davies Liu] fix python test 17ffe51 [Davies Liu] Merge branch 'master' of github.com:apache/spark into initcap ca46390 [Davies Liu] Merge branch 'master' of github.com:apache/spark into initcap 3a906e4 [Davies Liu] implement title case in UTF8String 8b2506a [HuJiayin] Update functions.py 2cd43e5 [HuJiayin] fix python style check b616c0e [HuJiayin] add python api 1f5a0ef [HuJiayin] add codegen 7e0c604 [HuJiayin] Merge branch 'master' of https://github.com/apache/spark into initcap 6a0b958 [HuJiayin] add column c79482d [HuJiayin] support soundex 7ce416b [HuJiayin] support initcap rebase code
-rw-r--r--python/pyspark/sql/functions.py12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala17
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala9
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java88
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java8
8 files changed, 156 insertions, 0 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 96975f54ff..a73ecc7d93 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -958,6 +958,18 @@ def substring_index(str, delim, count):
return Column(sc._jvm.functions.substring_index(_to_java_column(str), delim, count))
+@ignore_unicode_prefix
+@since(1.5)
+def initcap(col):
+ """Translate the first letter of each word to upper case in the sentence.
+
+ >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect()
+ [Row(v=u'Ab Cd')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.initcap(_to_java_column(col)))
+
+
@since(1.5)
def size(col):
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6e144518bb..8fafd7778a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -178,6 +178,7 @@ object FunctionRegistry {
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[FormatNumber]("format_number"),
+ expression[InitCap]("initcap"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[Length]("length"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 4d78c55497..80c64e5689 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -598,6 +598,23 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
}
/**
+ * Returns string, with the first letter of each word in uppercase.
+ * Words are delimited by whitespace.
+ */
+case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[DataType] = Seq(StringType)
+ override def dataType: DataType = StringType
+
+ override def nullSafeEval(string: Any): Any = {
+ string.asInstanceOf[UTF8String].toTitleCase
+ }
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ defineCodeGen(ctx, ev, str => s"$str.toTitleCase()")
+ }
+}
+
+/**
* Returns the string which repeat the given string value n times.
*/
case class StringRepeat(str: Expression, times: Expression)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 89c1e33420..906be701be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -377,6 +377,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}
+ test("initcap unit test") {
+ checkEvaluation(InitCap(Literal.create(null, StringType)), null)
+ checkEvaluation(InitCap(Literal("a b")), "A B")
+ checkEvaluation(InitCap(Literal(" a")), " A")
+ checkEvaluation(InitCap(Literal("the test")), "The Test")
+ // scalastyle:off
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ checkEvaluation(InitCap(Literal("世界")), "世界")
+ // scalastyle:on
+ }
+
+
test("Levenshtein distance") {
checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index babfe21879..818aa109f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1787,6 +1787,15 @@ object functions {
}
/**
+ * Returns string, with the first letter of each word in uppercase.
+ * Words are delimited by whitespace.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def initcap(e: Column): Column = InitCap(e.expr)
+
+ /**
* Locate the position of the first occurrence of substr column in the given string.
* Returns null if either of the arguments are null.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index f40233db0a..1c1be0c3cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -315,6 +315,15 @@ class StringFunctionsSuite extends QueryTest {
}
}
+ test("initcap function") {
+ val df = Seq(("ab", "a B")).toDF("l", "r")
+ checkAnswer(
+ df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B"))
+
+ checkAnswer(
+ df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B"))
+ }
+
test("number format function") {
val tuple =
("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 208503d2fd..213dc761bb 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -279,6 +279,29 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
* Returns the upper case of this string
*/
public UTF8String toUpperCase() {
+ if (numBytes == 0) {
+ return EMPTY_UTF8;
+ }
+
+ byte[] bytes = new byte[numBytes];
+ bytes[0] = (byte) Character.toTitleCase(getByte(0));
+ for (int i = 0; i < numBytes; i++) {
+ byte b = getByte(i);
+ if (numBytesForFirstByte(b) != 1) {
+ // fallback
+ return toUpperCaseSlow();
+ }
+ int upper = Character.toUpperCase((int) b);
+ if (upper > 127) {
+ // fallback
+ return toUpperCaseSlow();
+ }
+ bytes[i] = (byte) upper;
+ }
+ return fromBytes(bytes);
+ }
+
+ private UTF8String toUpperCaseSlow() {
return fromString(toString().toUpperCase());
}
@@ -286,10 +309,75 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
* Returns the lower case of this string
*/
public UTF8String toLowerCase() {
+ if (numBytes == 0) {
+ return EMPTY_UTF8;
+ }
+
+ byte[] bytes = new byte[numBytes];
+ bytes[0] = (byte) Character.toTitleCase(getByte(0));
+ for (int i = 0; i < numBytes; i++) {
+ byte b = getByte(i);
+ if (numBytesForFirstByte(b) != 1) {
+ // fallback
+ return toLowerCaseSlow();
+ }
+ int lower = Character.toLowerCase((int) b);
+ if (lower > 127) {
+ // fallback
+ return toLowerCaseSlow();
+ }
+ bytes[i] = (byte) lower;
+ }
+ return fromBytes(bytes);
+ }
+
+ private UTF8String toLowerCaseSlow() {
return fromString(toString().toLowerCase());
}
/**
+ * Returns the title case of this string, that could be used as title.
+ */
+ public UTF8String toTitleCase() {
+ if (numBytes == 0) {
+ return EMPTY_UTF8;
+ }
+
+ byte[] bytes = new byte[numBytes];
+ for (int i = 0; i < numBytes; i++) {
+ byte b = getByte(i);
+ if (i == 0 || getByte(i - 1) == ' ') {
+ if (numBytesForFirstByte(b) != 1) {
+ // fallback
+ return toTitleCaseSlow();
+ }
+ int upper = Character.toTitleCase(b);
+ if (upper > 127) {
+ // fallback
+ return toTitleCaseSlow();
+ }
+ bytes[i] = (byte) upper;
+ } else {
+ bytes[i] = b;
+ }
+ }
+ return fromBytes(bytes);
+ }
+
+ private UTF8String toTitleCaseSlow() {
+ StringBuffer sb = new StringBuffer();
+ String s = toString();
+ sb.append(s);
+ sb.setCharAt(0, Character.toTitleCase(sb.charAt(0)));
+ for (int i = 1; i < s.length(); i++) {
+ if (sb.charAt(i - 1) == ' ') {
+ sb.setCharAt(i, Character.toTitleCase(sb.charAt(i)));
+ }
+ }
+ return fromString(sb.toString());
+ }
+
+ /**
* Copy the bytes from the current UTF8String, and make a new UTF8String.
* @param start the start position of the current UTF8String in bytes.
* @param end the end position of the current UTF8String in bytes.
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index ed50cdcb29..9b3190f8f0 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -115,6 +115,14 @@ public class UTF8StringSuite {
}
@Test
+ public void titleCase() {
+ assertEquals(fromString(""), fromString("").toTitleCase());
+ assertEquals(fromString("Ab Bc Cd"), fromString("ab bc cd").toTitleCase());
+ assertEquals(fromString("Ѐ Ё Ђ Ѻ Ώ Ề"), fromString("ѐ ё ђ ѻ ώ ề").toTitleCase());
+ assertEquals(fromString("大千世界 数据砖头"), fromString("大千世界 数据砖头").toTitleCase());
+ }
+
+ @Test
public void concatTest() {
assertEquals(EMPTY_UTF8, concat());
assertEquals(null, concat((UTF8String) null));