aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-07-20 22:48:13 -0700
committerReynold Xin <rxin@databricks.com>2015-07-20 22:48:13 -0700
commit67570beed5950974126a91eacd48fd0fedfeb141 (patch)
tree6cb55459b2b8c42abf65dea3d538d9dd3136ca95
parent560b355ccd038ca044726c9c9fcffd14d02e6696 (diff)
downloadspark-67570beed5950974126a91eacd48fd0fedfeb141.tar.gz
spark-67570beed5950974126a91eacd48fd0fedfeb141.tar.bz2
spark-67570beed5950974126a91eacd48fd0fedfeb141.zip
[SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names.
It can be ambiguous whether that is a string literal or a column name. cc marmbrus Author: Reynold Xin <rxin@databricks.com> Closes #7556 from rxin/str-exprs and squashes the following commits: 92afa83 [Reynold Xin] [SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala459
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala59
5 files changed, 74 insertions, 455 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fafdae07c9..9c45b19624 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -684,7 +684,7 @@ object CombineLimits extends Rule[LogicalPlan] {
}
/**
- * Removes the inner [[CaseConversionExpression]] that are unnecessary because
+ * Removes the inner case conversion expressions that are unnecessary because
* the inner conversion is overwritten by the outer one.
*/
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 41b25d1836..8fa017610b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -69,7 +69,7 @@ object functions {
def column(colName: String): Column = Column(colName)
/**
- * Convert a number from one base to another for the specified expressions
+ * Convert a number in string format from one base to another.
*
* @group math_funcs
* @since 1.5.0
@@ -78,15 +78,6 @@ object functions {
Conv(num.expr, lit(fromBase).expr, lit(toBase).expr)
/**
- * Convert a number from one base to another for the specified expressions
- *
- * @group math_funcs
- * @since 1.5.0
- */
- def conv(numColName: String, fromBase: Int, toBase: Int): Column =
- conv(Column(numColName), fromBase, toBase)
-
- /**
* Creates a [[Column]] of literal value.
*
* The passed in object is returned directly if it is already a [[Column]].
@@ -628,14 +619,6 @@ object functions {
def isNaN(e: Column): Column = IsNaN(e.expr)
/**
- * Converts a string expression to lower case.
- *
- * @group normal_funcs
- * @since 1.3.0
- */
- def lower(e: Column): Column = Lower(e.expr)
-
- /**
* A column expression that generates monotonically increasing 64-bit integers.
*
* The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
@@ -792,14 +775,6 @@ object functions {
}
/**
- * Converts a string expression to upper case.
- *
- * @group normal_funcs
- * @since 1.3.0
- */
- def upper(e: Column): Column = Upper(e.expr)
-
- /**
* Computes bitwise NOT.
*
* @group normal_funcs
@@ -1106,9 +1081,8 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def greatest(exprs: Column*): Column = if (exprs.length < 2) {
- sys.error("GREATEST takes at least 2 parameters")
- } else {
+ def greatest(exprs: Column*): Column = {
+ require(exprs.length > 1, "greatest requires at least 2 arguments.")
Greatest(exprs.map(_.expr))
}
@@ -1120,9 +1094,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
- sys.error("GREATEST takes at least 2 parameters")
- } else {
+ def greatest(columnName: String, columnNames: String*): Column = {
greatest((columnName +: columnNames).map(Column.apply): _*)
}
@@ -1135,14 +1107,6 @@ object functions {
def hex(column: Column): Column = Hex(column.expr)
/**
- * Computes hex value of the given input.
- *
- * @group math_funcs
- * @since 1.5.0
- */
- def hex(colName: String): Column = hex(Column(colName))
-
- /**
* Inverse of hex. Interprets each pair of characters as a hexadecimal number
* and converts to the byte representation of number.
*
@@ -1152,15 +1116,6 @@ object functions {
def unhex(column: Column): Column = Unhex(column.expr)
/**
- * Inverse of hex. Interprets each pair of characters as a hexadecimal number
- * and converts to the byte representation of number.
- *
- * @group math_funcs
- * @since 1.5.0
- */
- def unhex(colName: String): Column = unhex(Column(colName))
-
- /**
* Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
*
* @group math_funcs
@@ -1233,9 +1188,8 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def least(exprs: Column*): Column = if (exprs.length < 2) {
- sys.error("LEAST takes at least 2 parameters")
- } else {
+ def least(exprs: Column*): Column = {
+ require(exprs.length > 1, "least requires at least 2 arguments.")
Least(exprs.map(_.expr))
}
@@ -1247,9 +1201,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
- def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
- sys.error("LEAST takes at least 2 parameters")
- } else {
+ def least(columnName: String, columnNames: String*): Column = {
least((columnName +: columnNames).map(Column.apply): _*)
}
@@ -1639,7 +1591,8 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Calculates the MD5 digest and returns the value as a 32 character hex string.
+ * Calculates the MD5 digest of a binary column and returns the value
+ * as a 32 character hex string.
*
* @group misc_funcs
* @since 1.5.0
@@ -1647,15 +1600,8 @@ object functions {
def md5(e: Column): Column = Md5(e.expr)
/**
- * Calculates the MD5 digest and returns the value as a 32 character hex string.
- *
- * @group misc_funcs
- * @since 1.5.0
- */
- def md5(columnName: String): Column = md5(Column(columnName))
-
- /**
- * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+ * Calculates the SHA-1 digest of a binary column and returns the value
+ * as a 40 character hex string.
*
* @group misc_funcs
* @since 1.5.0
@@ -1663,15 +1609,11 @@ object functions {
def sha1(e: Column): Column = Sha1(e.expr)
/**
- * Calculates the SHA-1 digest and returns the value as a 40 character hex string.
+ * Calculates the SHA-2 family of hash functions of a binary column and
+ * returns the value as a hex string.
*
- * @group misc_funcs
- * @since 1.5.0
- */
- def sha1(columnName: String): Column = sha1(Column(columnName))
-
- /**
- * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
+ * @param e column to compute SHA-2 on.
+ * @param numBits one of 224, 256, 384, or 512.
*
* @group misc_funcs
* @since 1.5.0
@@ -1683,29 +1625,14 @@ object functions {
}
/**
- * Calculates the SHA-2 family of hash functions and returns the value as a hex string.
- *
- * @group misc_funcs
- * @since 1.5.0
- */
- def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits)
-
- /**
- * Calculates the cyclic redundancy check value and returns the value as a bigint.
+ * Calculates the cyclic redundancy check value (CRC32) of a binary column and
+ * returns the value as a bigint.
*
* @group misc_funcs
* @since 1.5.0
*/
def crc32(e: Column): Column = Crc32(e.expr)
- /**
- * Calculates the cyclic redundancy check value and returns the value as a bigint.
- *
- * @group misc_funcs
- * @since 1.5.0
- */
- def crc32(columnName: String): Column = crc32(Column(columnName))
-
//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -1720,19 +1647,6 @@ object functions {
def concat(exprs: Column*): Column = Concat(exprs.map(_.expr))
/**
- * Concatenates input strings together into a single string.
- *
- * This is the variant of concat that takes in the column names.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- @scala.annotation.varargs
- def concat(columnName: String, columnNames: String*): Column = {
- concat((columnName +: columnNames).map(Column.apply): _*)
- }
-
- /**
* Concatenates input strings together into a single string, using the given separator.
*
* @group string_funcs
@@ -1744,19 +1658,6 @@ object functions {
}
/**
- * Concatenates input strings together into a single string, using the given separator.
- *
- * This is the variant of concat_ws that takes in the column names.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- @scala.annotation.varargs
- def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
- concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
- }
-
- /**
* Computes the length of a given string / binary value.
*
* @group string_funcs
@@ -1765,23 +1666,20 @@ object functions {
def length(e: Column): Column = Length(e.expr)
/**
- * Computes the length of a given string / binary column.
+ * Converts a string expression to lower case.
*
* @group string_funcs
- * @since 1.5.0
+ * @since 1.3.0
*/
- def length(columnName: String): Column = length(Column(columnName))
+ def lower(e: Column): Column = Lower(e.expr)
/**
- * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
- * and returns the result as a string.
- * If d is 0, the result has no decimal point or fractional part.
- * If d < 0, the result will be null.
+ * Converts a string expression to upper case.
*
* @group string_funcs
- * @since 1.5.0
+ * @since 1.3.0
*/
- def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
+ def upper(e: Column): Column = Upper(e.expr)
/**
* Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
@@ -1792,43 +1690,25 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def format_number(columnXName: String, d: Int): Column = {
- format_number(Column(columnXName), d)
- }
+ def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
/**
- * Computes the Levenshtein distance of the two given strings.
+ * Computes the Levenshtein distance of the two given string columns.
* @group string_funcs
* @since 1.5.0
*/
def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr)
/**
- * Computes the Levenshtein distance of the two given strings.
- * @group string_funcs
- * @since 1.5.0
- */
- def levenshtein(leftColumnName: String, rightColumnName: String): Column =
- levenshtein(Column(leftColumnName), Column(rightColumnName))
-
- /**
- * Computes the numeric value of the first character of the specified string value.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def ascii(e: Column): Column = Ascii(e.expr)
-
- /**
* Computes the numeric value of the first character of the specified string column.
*
* @group string_funcs
* @since 1.5.0
*/
- def ascii(columnName: String): Column = ascii(Column(columnName))
+ def ascii(e: Column): Column = Ascii(e.expr)
/**
- * Trim the spaces from both ends for the specified string value.
+ * Trim the spaces from both ends for the specified string column.
*
* @group string_funcs
* @since 1.5.0
@@ -1836,14 +1716,6 @@ object functions {
def trim(e: Column): Column = StringTrim(e.expr)
/**
- * Trim the spaces from both ends for the specified column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def trim(columnName: String): Column = trim(Column(columnName))
-
- /**
* Trim the spaces from left end for the specified string value.
*
* @group string_funcs
@@ -1852,14 +1724,6 @@ object functions {
def ltrim(e: Column): Column = StringTrimLeft(e.expr)
/**
- * Trim the spaces from left end for the specified column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def ltrim(columnName: String): Column = ltrim(Column(columnName))
-
- /**
* Trim the spaces from right end for the specified string value.
*
* @group string_funcs
@@ -1868,25 +1732,6 @@ object functions {
def rtrim(e: Column): Column = StringTrimRight(e.expr)
/**
- * Trim the spaces from right end for the specified column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def rtrim(columnName: String): Column = rtrim(Column(columnName))
-
- /**
- * Format strings in printf-style.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- @scala.annotation.varargs
- def formatString(format: Column, arguments: Column*): Column = {
- StringFormat((format +: arguments).map(_.expr): _*)
- }
-
- /**
* Format strings in printf-style.
* NOTE: `format` is the string value of the formatter, not column name.
*
@@ -1899,18 +1744,6 @@ object functions {
}
/**
- * Locate the position of the first occurrence of substr value in the given string.
- * Returns null if either of the arguments are null.
- *
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
- * could not be found in str.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub))
-
- /**
* Locate the position of the first occurrence of substr column in the given string.
* Returns null if either of the arguments are null.
*
@@ -1920,10 +1753,10 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr)
+ def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
/**
- * Locate the position of the first occurrence of substr.
+ * Locate the position of the first occurrence of substr in a string column.
*
* NOTE: The position is not zero based, but 1 based index, returns 0 if substr
* could not be found in str.
@@ -1931,77 +1764,26 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def locate(substr: String, str: String): Column = {
- locate(Column(substr), Column(str))
+ def locate(substr: String, str: Column): Column = {
+ new StringLocate(lit(substr).expr, str.expr)
}
/**
- * Locate the position of the first occurrence of substr.
+ * Locate the position of the first occurrence of substr in a string column, after position pos.
*
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
+ * NOTE: The position is not zero based, but 1 based index. returns 0 if substr
* could not be found in str.
*
* @group string_funcs
* @since 1.5.0
*/
- def locate(substr: Column, str: Column): Column = {
- new StringLocate(substr.expr, str.expr)
+ def locate(substr: String, str: Column, pos: Int): Column = {
+ StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
}
/**
- * Locate the position of the first occurrence of substr in a given string after position pos.
- *
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
- * could not be found in str.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def locate(substr: String, str: String, pos: String): Column = {
- locate(Column(substr), Column(str), Column(pos))
- }
-
- /**
- * Locate the position of the first occurrence of substr in a given string after position pos.
- *
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
- * could not be found in str.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def locate(substr: Column, str: Column, pos: Column): Column = {
- StringLocate(substr.expr, str.expr, pos.expr)
- }
-
- /**
- * Locate the position of the first occurrence of substr in a given string after position pos.
- *
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
- * could not be found in str.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def locate(substr: Column, str: Column, pos: Int): Column = {
- StringLocate(substr.expr, str.expr, lit(pos).expr)
- }
-
- /**
- * Locate the position of the first occurrence of substr in a given string after position pos.
- *
- * NOTE: The position is not zero based, but 1 based index, returns 0 if substr
- * could not be found in str.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def locate(substr: String, str: String, pos: Int): Column = {
- locate(Column(substr), Column(str), lit(pos))
- }
-
- /**
- * Computes the specified value from binary to a base64 string.
+ * Computes the BASE64 encoding of a binary column and returns it as a string column.
+ * This is the reverse of unbase64.
*
* @group string_funcs
* @since 1.5.0
@@ -2009,15 +1791,8 @@ object functions {
def base64(e: Column): Column = Base64(e.expr)
/**
- * Computes the specified column from binary to a base64 string.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def base64(columnName: String): Column = base64(Column(columnName))
-
- /**
- * Computes the specified value from a base64 string to binary.
+ * Decodes a BASE64 encoded string column and returns it as a binary column.
+ * This is the reverse of base64.
*
* @group string_funcs
* @since 1.5.0
@@ -2025,51 +1800,13 @@ object functions {
def unbase64(e: Column): Column = UnBase64(e.expr)
/**
- * Computes the specified column from a base64 string to binary.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def unbase64(columnName: String): Column = unbase64(Column(columnName))
-
- /**
* Left-padded with pad to a length of len.
*
* @group string_funcs
* @since 1.5.0
*/
- def lpad(str: String, len: String, pad: String): Column = {
- lpad(Column(str), Column(len), Column(pad))
- }
-
- /**
- * Left-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def lpad(str: Column, len: Column, pad: Column): Column = {
- StringLPad(str.expr, len.expr, pad.expr)
- }
-
- /**
- * Left-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def lpad(str: Column, len: Int, pad: Column): Column = {
- StringLPad(str.expr, lit(len).expr, pad.expr)
- }
-
- /**
- * Left-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def lpad(str: String, len: Int, pad: String): Column = {
- lpad(Column(str), len, Column(pad))
+ def lpad(str: Column, len: Int, pad: String): Column = {
+ StringLPad(str.expr, lit(len).expr, lit(pad).expr)
}
/**
@@ -2083,18 +1820,6 @@ object functions {
def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)
/**
- * Computes the first argument into a binary from a string using the provided character set
- * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null.
- * NOTE: charset represents the string value of the character set, not the column name.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def encode(columnName: String, charset: String): Column =
- encode(Column(columnName), charset)
-
- /**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
@@ -2105,105 +1830,23 @@ object functions {
def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)
/**
- * Computes the first argument into a string from a binary using the provided character set
- * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null.
- * NOTE: charset represents the string value of the character set, not the column name.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def decode(columnName: String, charset: String): Column =
- decode(Column(columnName), charset)
-
- /**
- * Right-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def rpad(str: String, len: String, pad: String): Column = {
- rpad(Column(str), Column(len), Column(pad))
- }
-
- /**
- * Right-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def rpad(str: Column, len: Column, pad: Column): Column = {
- StringRPad(str.expr, len.expr, pad.expr)
- }
-
- /**
- * Right-padded with pad to a length of len.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def rpad(str: String, len: Int, pad: String): Column = {
- rpad(Column(str), len, Column(pad))
- }
-
- /**
* Right-padded with pad to a length of len.
*
* @group string_funcs
* @since 1.5.0
*/
- def rpad(str: Column, len: Int, pad: Column): Column = {
- StringRPad(str.expr, lit(len).expr, pad.expr)
- }
-
- /**
- * Repeat the string value of the specified column n times.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def repeat(strColumn: String, timesColumn: String): Column = {
- repeat(Column(strColumn), Column(timesColumn))
+ def rpad(str: Column, len: Int, pad: String): Column = {
+ StringRPad(str.expr, lit(len).expr, lit(pad).expr)
}
/**
- * Repeat the string expression value n times.
+ * Repeats a string column n times, and returns it as a new string column.
*
* @group string_funcs
* @since 1.5.0
*/
- def repeat(str: Column, times: Column): Column = {
- StringRepeat(str.expr, times.expr)
- }
-
- /**
- * Repeat the string value of the specified column n times.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def repeat(strColumn: String, times: Int): Column = {
- repeat(Column(strColumn), times)
- }
-
- /**
- * Repeat the string expression value n times.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def repeat(str: Column, times: Int): Column = {
- StringRepeat(str.expr, lit(times).expr)
- }
-
- /**
- * Splits str around pattern (pattern is a regular expression).
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def split(strColumnName: String, pattern: String): Column = {
- split(Column(strColumnName), pattern)
+ def repeat(str: Column, n: Int): Column = {
+ StringRepeat(str.expr, lit(n).expr)
}
/**
@@ -2218,16 +1861,6 @@ object functions {
}
/**
- * Reversed the string for the specified column.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- def reverse(str: String): Column = {
- reverse(Column(str))
- }
-
- /**
* Reversed the string for the specified value.
*
* @group string_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 29f1197a85..8d2ff2f969 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -160,7 +160,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("misc md5 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
- df.select(md5($"a"), md5("b")),
+ df.select(md5($"a"), md5($"b")),
Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c"))
checkAnswer(
@@ -171,7 +171,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("misc sha1 function") {
val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b")
checkAnswer(
- df.select(sha1($"a"), sha1("b")),
+ df.select(sha1($"a"), sha1($"b")),
Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8"))
val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b")
@@ -183,7 +183,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("misc sha2 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
- df.select(sha2($"a", 256), sha2("b", 256)),
+ df.select(sha2($"a", 256), sha2($"b", 256)),
Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78",
"7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89"))
@@ -200,7 +200,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("misc crc32 function") {
val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
checkAnswer(
- df.select(crc32($"a"), crc32("b")),
+ df.select(crc32($"a"), crc32($"b")),
Row(2743272264L, 2180413220L))
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index a51523f1a7..21256704a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -176,7 +176,6 @@ class MathExpressionsSuite extends QueryTest {
test("conv") {
val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase")
checkAnswer(df.select(conv('num, 10, 16)), Row("14D"))
- checkAnswer(df.select(conv("num", 10, 16)), Row("14D"))
checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4"))
checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457"))
checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 413f3858d6..4551192b15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -52,14 +52,14 @@ class StringFunctionsSuite extends QueryTest {
test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
- checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
+ checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1)))
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(
- df.select(ascii($"a"), ascii("b")),
+ df.select(ascii($"a"), ascii($"b")),
Row(97, 0))
checkAnswer(
@@ -71,8 +71,8 @@ class StringFunctionsSuite extends QueryTest {
val bytes = Array[Byte](1, 2, 3, 4)
val df = Seq((bytes, "AQIDBA==")).toDF("a", "b")
checkAnswer(
- df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")),
- Row("AQIDBA==", "AQIDBA==", bytes, bytes))
+ df.select(base64($"a"), unbase64($"b")),
+ Row("AQIDBA==", bytes))
checkAnswer(
df.selectExpr("base64(a)", "unbase64(b)"),
@@ -85,12 +85,8 @@ class StringFunctionsSuite extends QueryTest {
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
- df.select(
- encode($"a", "utf-8"),
- encode("a", "utf-8"),
- decode($"c", "utf-8"),
- decode("c", "utf-8")),
- Row(bytes, bytes, "大千世界", "大千世界"))
+ df.select(encode($"a", "utf-8"), decode($"c", "utf-8")),
+ Row(bytes, "大千世界"))
checkAnswer(
df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
@@ -114,8 +110,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
checkAnswer(
- df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
- Row("aa123cc", "aa123cc"))
+ df.select(formatString("aa%d%s", "b", "c")),
+ Row("aa123cc"))
checkAnswer(
df.selectExpr("printf(a, b, c)"),
@@ -126,8 +122,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c")
checkAnswer(
- df.select(instr($"a", $"b"), instr("a", "b")),
- Row(1, 1))
+ df.select(instr($"a", "aa")),
+ Row(1))
checkAnswer(
df.selectExpr("instr(a, b)"),
@@ -138,10 +134,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
checkAnswer(
- df.select(
- locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1),
- locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")),
- Row(1, 1, 2, 2, 2, 2))
+ df.select(locate("aa", $"a"), locate("aa", $"a", 1)),
+ Row(1, 2))
checkAnswer(
df.selectExpr("locate(b, a)", "locate(b, a, d)"),
@@ -152,10 +146,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("hi", 5, "??")).toDF("a", "b", "c")
checkAnswer(
- df.select(
- lpad($"a", $"b", $"c"), rpad("a", "b", "c"),
- lpad($"a", 1, $"c"), rpad("a", 1, "c")),
- Row("???hi", "hi???", "h", "h"))
+ df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")),
+ Row("h", "???hi", "h", "hi???"))
checkAnswer(
df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"),
@@ -166,9 +158,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("hi", 2)).toDF("a", "b")
checkAnswer(
- df.select(
- repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")),
- Row("hihi", "hihi", "hihi", "hihi"))
+ df.select(repeat($"a", 2)),
+ Row("hihi"))
checkAnswer(
df.selectExpr("repeat(a, 2)", "repeat(a, b)"),
@@ -179,7 +170,7 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("hi", "hhhi")).toDF("a", "b")
checkAnswer(
- df.select(reverse($"a"), reverse("b")),
+ df.select(reverse($"a"), reverse($"b")),
Row("ih", "ihhh"))
checkAnswer(
@@ -199,10 +190,8 @@ class StringFunctionsSuite extends QueryTest {
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
checkAnswer(
- df.select(
- split($"a", "[1-9]+"),
- split("a", "[1-9]+")),
- Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc")))
+ df.select(split($"a", "[1-9]+")),
+ Row(Seq("aa", "bb", "cc")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+')"),
@@ -212,8 +201,8 @@ class StringFunctionsSuite extends QueryTest {
test("string / binary length function") {
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
checkAnswer(
- df.select(length($"a"), length("a"), length($"b"), length("b")),
- Row(3, 3, 4, 4))
+ df.select(length($"a"), length($"b")),
+ Row(3, 4))
checkAnswer(
df.selectExpr("length(a)", "length(b)"),
@@ -243,10 +232,8 @@ class StringFunctionsSuite extends QueryTest {
"h") // decimal 7.128381
checkAnswer(
- df.select(
- format_number($"f", 4),
- format_number("f", 4)),
- Row("5.0000", "5.0000"))
+ df.select(format_number($"f", 4)),
+ Row("5.0000"))
checkAnswer(
df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer