diff options
author | Reynold Xin <rxin@databricks.com> | 2015-07-20 22:48:13 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-20 22:48:13 -0700 |
commit | 67570beed5950974126a91eacd48fd0fedfeb141 (patch) | |
tree | 6cb55459b2b8c42abf65dea3d538d9dd3136ca95 | |
parent | 560b355ccd038ca044726c9c9fcffd14d02e6696 (diff) | |
download | spark-67570beed5950974126a91eacd48fd0fedfeb141.tar.gz spark-67570beed5950974126a91eacd48fd0fedfeb141.tar.bz2 spark-67570beed5950974126a91eacd48fd0fedfeb141.zip |
[SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names.
When a function takes a String argument, it can be ambiguous whether that argument is a string literal or a column name.
cc marmbrus
Author: Reynold Xin <rxin@databricks.com>
Closes #7556 from rxin/str-exprs and squashes the following commits:
92afa83 [Reynold Xin] [SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names.
5 files changed, 74 insertions, 455 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fafdae07c9..9c45b19624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -684,7 +684,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[CaseConversionExpression]] that are unnecessary because + * Removes the inner case conversion expressions that are unnecessary because * the inner conversion is overwritten by the outer one. */ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 41b25d1836..8fa017610b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -69,7 +69,7 @@ object functions { def column(colName: String): Column = Column(colName) /** - * Convert a number from one base to another for the specified expressions + * Convert a number in string format from one base to another. * * @group math_funcs * @since 1.5.0 @@ -78,15 +78,6 @@ object functions { Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) /** - * Convert a number from one base to another for the specified expressions - * - * @group math_funcs - * @since 1.5.0 - */ - def conv(numColName: String, fromBase: Int, toBase: Int): Column = - conv(Column(numColName), fromBase, toBase) - - /** * Creates a [[Column]] of literal value. * * The passed in object is returned directly if it is already a [[Column]]. @@ -628,14 +619,6 @@ object functions { def isNaN(e: Column): Column = IsNaN(e.expr) /** - * Converts a string expression to lower case. 
- * - * @group normal_funcs - * @since 1.3.0 - */ - def lower(e: Column): Column = Lower(e.expr) - - /** * A column expression that generates monotonically increasing 64-bit integers. * * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. @@ -792,14 +775,6 @@ object functions { } /** - * Converts a string expression to upper case. - * - * @group normal_funcs - * @since 1.3.0 - */ - def upper(e: Column): Column = Upper(e.expr) - - /** * Computes bitwise NOT. * * @group normal_funcs @@ -1106,9 +1081,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("GREATEST takes at least 2 parameters") - } else { + def greatest(exprs: Column*): Column = { + require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1120,9 +1094,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("GREATEST takes at least 2 parameters") - } else { + def greatest(columnName: String, columnNames: String*): Column = { greatest((columnName +: columnNames).map(Column.apply): _*) } @@ -1135,14 +1107,6 @@ object functions { def hex(column: Column): Column = Hex(column.expr) /** - * Computes hex value of the given input. - * - * @group math_funcs - * @since 1.5.0 - */ - def hex(colName: String): Column = hex(Column(colName)) - - /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number * and converts to the byte representation of number. * @@ -1152,15 +1116,6 @@ object functions { def unhex(column: Column): Column = Unhex(column.expr) /** - * Inverse of hex. Interprets each pair of characters as a hexadecimal number - * and converts to the byte representation of number. 
- * - * @group math_funcs - * @since 1.5.0 - */ - def unhex(colName: String): Column = unhex(Column(colName)) - - /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * * @group math_funcs @@ -1233,9 +1188,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("LEAST takes at least 2 parameters") - } else { + def least(exprs: Column*): Column = { + require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1247,9 +1201,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("LEAST takes at least 2 parameters") - } else { + def least(columnName: String, columnNames: String*): Column = { least((columnName +: columnNames).map(Column.apply): _*) } @@ -1639,7 +1591,8 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1647,15 +1600,8 @@ object functions { def md5(e: Column): Column = Md5(e.expr) /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def md5(columnName: String): Column = md5(Column(columnName)) - - /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * Calculates the SHA-1 digest of a binary column and returns the value + * as a 40 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1663,15 +1609,11 @@ object functions { def sha1(e: Column): Column = Sha1(e.expr) /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. 
+ * Calculates the SHA-2 family of hash functions of a binary column and + * returns the value as a hex string. * - * @group misc_funcs - * @since 1.5.0 - */ - def sha1(columnName: String): Column = sha1(Column(columnName)) - - /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * @param e column to compute SHA-2 on. + * @param numBits one of 224, 256, 384, or 512. * * @group misc_funcs * @since 1.5.0 @@ -1683,29 +1625,14 @@ object functions { } /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) - - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. + * Calculates the cyclic redundancy check value (CRC32) of a binary column and + * returns the value as a bigint. * * @group misc_funcs * @since 1.5.0 */ def crc32(e: Column): Column = Crc32(e.expr) - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. - * - * @group misc_funcs - * @since 1.5.0 - */ - def crc32(columnName: String): Column = crc32(Column(columnName)) - ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1720,19 +1647,6 @@ object functions { def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) /** - * Concatenates input strings together into a single string. - * - * This is the variant of concat that takes in the column names. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(columnName: String, columnNames: String*): Column = { - concat((columnName +: columnNames).map(Column.apply): _*) - } - - /** * Concatenates input strings together into a single string, using the given separator. 
* * @group string_funcs @@ -1744,19 +1658,6 @@ object functions { } /** - * Concatenates input strings together into a single string, using the given separator. - * - * This is the variant of concat_ws that takes in the column names. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat_ws(sep: String, columnName: String, columnNames: String*): Column = { - concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*) - } - - /** * Computes the length of a given string / binary value. * * @group string_funcs @@ -1765,23 +1666,20 @@ object functions { def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string / binary column. + * Converts a string expression to lower case. * * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def length(columnName: String): Column = length(Column(columnName)) + def lower(e: Column): Column = Lower(e.expr) /** - * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, - * and returns the result as a string. - * If d is 0, the result has no decimal point or fractional part. - * If d < 0, the result will be null. + * Converts a string expression to upper case. * * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def upper(e: Column): Column = Upper(e.expr) /** * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, @@ -1792,43 +1690,25 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(columnXName: String, d: Int): Column = { - format_number(Column(columnXName), d) - } + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) /** - * Computes the Levenshtein distance of the two given strings. + * Computes the Levenshtein distance of the two given string columns. 
* @group string_funcs * @since 1.5.0 */ def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) /** - * Computes the Levenshtein distance of the two given strings. - * @group string_funcs - * @since 1.5.0 - */ - def levenshtein(leftColumnName: String, rightColumnName: String): Column = - levenshtein(Column(leftColumnName), Column(rightColumnName)) - - /** - * Computes the numeric value of the first character of the specified string value. - * - * @group string_funcs - * @since 1.5.0 - */ - def ascii(e: Column): Column = Ascii(e.expr) - - /** * Computes the numeric value of the first character of the specified string column. * * @group string_funcs * @since 1.5.0 */ - def ascii(columnName: String): Column = ascii(Column(columnName)) + def ascii(e: Column): Column = Ascii(e.expr) /** - * Trim the spaces from both ends for the specified string value. + * Trim the spaces from both ends for the specified string column. * * @group string_funcs * @since 1.5.0 @@ -1836,14 +1716,6 @@ object functions { def trim(e: Column): Column = StringTrim(e.expr) /** - * Trim the spaces from both ends for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def trim(columnName: String): Column = trim(Column(columnName)) - - /** * Trim the spaces from left end for the specified string value. * * @group string_funcs @@ -1852,14 +1724,6 @@ object functions { def ltrim(e: Column): Column = StringTrimLeft(e.expr) /** - * Trim the spaces from left end for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def ltrim(columnName: String): Column = ltrim(Column(columnName)) - - /** * Trim the spaces from right end for the specified string value. * * @group string_funcs @@ -1868,25 +1732,6 @@ object functions { def rtrim(e: Column): Column = StringTrimRight(e.expr) /** - * Trim the spaces from right end for the specified column. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def rtrim(columnName: String): Column = rtrim(Column(columnName)) - - /** - * Format strings in printf-style. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } - - /** * Format strings in printf-style. * NOTE: `format` is the string value of the formatter, not column name. * @@ -1899,18 +1744,6 @@ object functions { } /** - * Locate the position of the first occurrence of substr value in the given string. - * Returns null if either of the arguments are null. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) - - /** * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. * @@ -1920,10 +1753,10 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr. + * Locate the position of the first occurrence of substr in a string column. * * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. @@ -1931,77 +1764,26 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: String): Column = { - locate(Column(substr), Column(str)) + def locate(substr: String, str: Column): Column = { + new StringLocate(lit(substr).expr, str.expr) } /** - * Locate the position of the first occurrence of substr. 
+ * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * NOTE: The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: Column, str: Column): Column = { - new StringLocate(substr.expr, str.expr) + def locate(substr: String, str: Column, pos: Int): Column = { + StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: String, str: String, pos: String): Column = { - locate(Column(substr), Column(str), Column(pos)) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: Column, str: Column, pos: Column): Column = { - StringLocate(substr.expr, str.expr, pos.expr) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: Column, str: Column, pos: Int): Column = { - StringLocate(substr.expr, str.expr, lit(pos).expr) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: String, str: String, pos: Int): Column = { - locate(Column(substr), Column(str), lit(pos)) - } - - /** - * Computes the specified value from binary to a base64 string. + * Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. * * @group string_funcs * @since 1.5.0 @@ -2009,15 +1791,8 @@ object functions { def base64(e: Column): Column = Base64(e.expr) /** - * Computes the specified column from binary to a base64 string. - * - * @group string_funcs - * @since 1.5.0 - */ - def base64(columnName: String): Column = base64(Column(columnName)) - - /** - * Computes the specified value from a base64 string to binary. + * Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. * * @group string_funcs * @since 1.5.0 @@ -2025,51 +1800,13 @@ object functions { def unbase64(e: Column): Column = UnBase64(e.expr) /** - * Computes the specified column from a base64 string to binary. - * - * @group string_funcs - * @since 1.5.0 - */ - def unbase64(columnName: String): Column = unbase64(Column(columnName)) - - /** * Left-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def lpad(str: String, len: String, pad: String): Column = { - lpad(Column(str), Column(len), Column(pad)) - } - - /** - * Left-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: Column, len: Column, pad: Column): Column = { - StringLPad(str.expr, len.expr, pad.expr) - } - - /** - * Left-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: Column, len: Int, pad: Column): Column = { - StringLPad(str.expr, lit(len).expr, pad.expr) - } - - /** - * Left-padded with pad to a length of len. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: String, len: Int, pad: String): Column = { - lpad(Column(str), len, Column(pad)) + def lpad(str: Column, len: Int, pad: String): Column = { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) } /** @@ -2083,18 +1820,6 @@ object functions { def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) /** - * Computes the first argument into a binary from a string using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs - * @since 1.5.0 - */ - def encode(columnName: String, charset: String): Column = - encode(Column(columnName), charset) - - /** * Computes the first argument into a string from a binary using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. @@ -2105,105 +1830,23 @@ object functions { def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs - * @since 1.5.0 - */ - def decode(columnName: String, charset: String): Column = - decode(Column(columnName), charset) - - /** - * Right-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: String, len: String, pad: String): Column = { - rpad(Column(str), Column(len), Column(pad)) - } - - /** - * Right-padded with pad to a length of len. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: Column, len: Column, pad: Column): Column = { - StringRPad(str.expr, len.expr, pad.expr) - } - - /** - * Right-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: String, len: Int, pad: String): Column = { - rpad(Column(str), len, Column(pad)) - } - - /** * Right-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: Column): Column = { - StringRPad(str.expr, lit(len).expr, pad.expr) - } - - /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(strColumn: String, timesColumn: String): Column = { - repeat(Column(strColumn), Column(timesColumn)) + def rpad(str: Column, len: Int, pad: String): Column = { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) } /** - * Repeat the string expression value n times. + * Repeats a string column n times, and returns it as a new string column. * * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, times: Column): Column = { - StringRepeat(str.expr, times.expr) - } - - /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(strColumn: String, times: Int): Column = { - repeat(Column(strColumn), times) - } - - /** - * Repeat the string expression value n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(str: Column, times: Int): Column = { - StringRepeat(str.expr, lit(times).expr) - } - - /** - * Splits str around pattern (pattern is a regular expression). 
- * - * @group string_funcs - * @since 1.5.0 - */ - def split(strColumnName: String, pattern: String): Column = { - split(Column(strColumnName), pattern) + def repeat(str: Column, n: Int): Column = { + StringRepeat(str.expr, lit(n).expr) } /** @@ -2218,16 +1861,6 @@ object functions { } /** - * Reversed the string for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: String): Column = { - reverse(Column(str)) - } - - /** * Reversed the string for the specified value. * * @group string_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 29f1197a85..8d2ff2f969 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -160,7 +160,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(md5($"a"), md5("b")), + df.select(md5($"a"), md5($"b")), Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) checkAnswer( @@ -171,7 +171,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha1 function") { val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") checkAnswer( - df.select(sha1($"a"), sha1("b")), + df.select(sha1($"a"), sha1($"b")), Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") @@ -183,7 +183,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(sha2($"a", 256), sha2("b", 256)), + df.select(sha2($"a", 256), sha2($"b", 256)), Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", 
"7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) @@ -200,7 +200,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc crc32 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(crc32($"a"), crc32("b")), + df.select(crc32($"a"), crc32($"b")), Row(2743272264L, 2180413220L)) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index a51523f1a7..21256704a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -176,7 +176,6 @@ class MathExpressionsSuite extends QueryTest { test("conv") { val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) - checkAnswer(df.select(conv("num", 10, 16)), Row("14D")) checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 413f3858d6..4551192b15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -52,14 +52,14 @@ class StringFunctionsSuite extends QueryTest { test("string Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") - checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) } test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") 
checkAnswer( - df.select(ascii($"a"), ascii("b")), + df.select(ascii($"a"), ascii($"b")), Row(97, 0)) checkAnswer( @@ -71,8 +71,8 @@ class StringFunctionsSuite extends QueryTest { val bytes = Array[Byte](1, 2, 3, 4) val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") checkAnswer( - df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), - Row("AQIDBA==", "AQIDBA==", bytes, bytes)) + df.select(base64($"a"), unbase64($"b")), + Row("AQIDBA==", bytes)) checkAnswer( df.selectExpr("base64(a)", "unbase64(b)"), @@ -85,12 +85,8 @@ class StringFunctionsSuite extends QueryTest { // non ascii characters are not allowed in the code, so we disable the scalastyle here. val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") checkAnswer( - df.select( - encode($"a", "utf-8"), - encode("a", "utf-8"), - decode($"c", "utf-8"), - decode("c", "utf-8")), - Row(bytes, bytes, "大千世界", "大千世界")) + df.select(encode($"a", "utf-8"), decode($"c", "utf-8")), + Row(bytes, "大千世界")) checkAnswer( df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), @@ -114,8 +110,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") checkAnswer( - df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) + df.select(formatString("aa%d%s", "b", "c")), + Row("aa123cc")) checkAnswer( df.selectExpr("printf(a, b, c)"), @@ -126,8 +122,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") checkAnswer( - df.select(instr($"a", $"b"), instr("a", "b")), - Row(1, 1)) + df.select(instr($"a", "aa")), + Row(1)) checkAnswer( df.selectExpr("instr(a, b)"), @@ -138,10 +134,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") checkAnswer( - df.select( - locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), - locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), - Row(1, 1, 2, 2, 
2, 2)) + df.select(locate("aa", $"a"), locate("aa", $"a", 1)), + Row(1, 2)) checkAnswer( df.selectExpr("locate(b, a)", "locate(b, a, d)"), @@ -152,10 +146,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") checkAnswer( - df.select( - lpad($"a", $"b", $"c"), rpad("a", "b", "c"), - lpad($"a", 1, $"c"), rpad("a", 1, "c")), - Row("???hi", "hi???", "h", "h")) + df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")), + Row("h", "???hi", "h", "hi???")) checkAnswer( df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), @@ -166,9 +158,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", 2)).toDF("a", "b") checkAnswer( - df.select( - repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), - Row("hihi", "hihi", "hihi", "hihi")) + df.select(repeat($"a", 2)), + Row("hihi")) checkAnswer( df.selectExpr("repeat(a, 2)", "repeat(a, b)"), @@ -179,7 +170,7 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", "hhhi")).toDF("a", "b") checkAnswer( - df.select(reverse($"a"), reverse("b")), + df.select(reverse($"a"), reverse($"b")), Row("ih", "ihhh")) checkAnswer( @@ -199,10 +190,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") checkAnswer( - df.select( - split($"a", "[1-9]+"), - split("a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + df.select(split($"a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"))) checkAnswer( df.selectExpr("split(a, '[1-9]+')"), @@ -212,8 +201,8 @@ class StringFunctionsSuite extends QueryTest { test("string / binary length function") { val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") checkAnswer( - df.select(length($"a"), length("a"), length($"b"), length("b")), - Row(3, 3, 4, 4)) + df.select(length($"a"), length($"b")), + Row(3, 4)) checkAnswer( df.selectExpr("length(a)", "length(b)"), @@ -243,10 +232,8 @@ 
class StringFunctionsSuite extends QueryTest { "h") // decimal 7.128381 checkAnswer( - df.select( - format_number($"f", 4), - format_number("f", 4)), - Row("5.0000", "5.0000")) + df.select(format_number($"f", 4)), + Row("5.0000")) checkAnswer( df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer |