author     Cheng Hao <hao.cheng@intel.com>       2015-07-05 21:50:52 -0700
committer  Reynold Xin <rxin@databricks.com>     2015-07-05 21:50:52 -0700
commit     6d0411b4f3a202cfb53f638ee5fd49072b42d3a6 (patch)
tree       636bc4585723533528057fea6a04f5aac7fd26b6 /sql
parent     a0cb111b22cb093e86b0daeecb3dcc41d095df40 (diff)
[SQL][Minor] Update the DataFrame API for encode/decode
This is a follow-up of #6843.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #7230 from chenghao-intel/str_funcs2_followup and squashes the following commits:

52cc553 [Cheng Hao] update the code as comment
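From the caller's side, the change replaces the charset Column parameter with a plain String. An illustrative before/after sketch (the DataFrame and column names here are hypothetical, not taken from the patch):

// Before: the charset had to be supplied as a Column, e.g. via lit(...) or a column reference.
df.select(encode($"name", lit("utf-8")))

// After: the charset is passed as a plain string value.
df.select(encode($"name", "utf-8"))
df.select(decode($"raw", "utf-8"))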
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala | 21
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 8
3 files changed, 25 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 6de40629ff..1a14a7a449 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.).
+ * If either argument is null, the result will also be null.
*/
-case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
- override def children: Seq[Expression] = bin :: charset :: Nil
- override def foldable: Boolean = bin.foldable && charset.foldable
- override def nullable: Boolean = bin.nullable || charset.nullable
+case class Decode(bin: Expression, charset: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = bin
+ override def right: Expression = charset
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)
@@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with
/**
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
- * If either argument is null, the result will also be null. (As of Hive 0.12.0.)
+ * If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
- extends Expression with ExpectsInputTypes {
- override def children: Seq[Expression] = value :: charset :: Nil
- override def foldable: Boolean = value.foldable && charset.foldable
- override def nullable: Boolean = value.nullable || charset.nullable
+ extends BinaryExpression with ExpectsInputTypes {
+
+ override def left: Expression = value
+ override def right: Expression = charset
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
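The removed children/foldable/nullable overrides are not lost: BinaryExpression derives them from left and right. A simplified sketch of the behavior the base class is assumed to provide (it lives elsewhere in Catalyst and is not part of this diff):

// Simplified sketch of the Catalyst BinaryExpression contract relied on above;
// the real base class is defined outside this patch.
abstract class BinaryExpression extends Expression {
  def left: Expression
  def right: Expression

  override def children: Seq[Expression] = Seq(left, right)
  override def foldable: Boolean = left.foldable && right.foldable
  override def nullable: Boolean = left.nullable || right.nullable
}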
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index abcfc0b650..f80291776f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1666,18 +1666,19 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)
+ def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)
/**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
+ * NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
- def encode(columnName: String, charsetColumnName: String): Column =
- encode(Column(columnName), Column(charsetColumnName))
+ def encode(columnName: String, charset: String): Column =
+ encode(Column(columnName), charset)
/**
* Computes the first argument into a string from a binary using the provided character set
@@ -1687,18 +1688,19 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
- def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)
+ def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)
/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
+ * NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
- def decode(columnName: String, charsetColumnName: String): Column =
- decode(Column(columnName), Column(charsetColumnName))
+ def decode(columnName: String, charset: String): Column =
+ decode(Column(columnName), charset)
//////////////////////////////////////////////////////////////////////////////////////////////
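A self-contained usage sketch of the updated functions; the SQLContext, implicits import, and sample data are assumptions for illustration, not part of the patch:

// Hedged usage sketch; sqlContext and the column names are assumed, not from this diff.
import org.apache.spark.sql.functions.{decode, encode}
import sqlContext.implicits._

val df = Seq(("hello", "hello".getBytes("UTF-8"))).toDF("s", "b")

// The second argument is now the charset value itself, not a column name.
df.select(encode($"s", "utf-8"), decode($"b", "utf-8")).show()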
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bc455a922d..afba28515e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest {
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
- df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
+ df.select(
+ encode($"a", "utf-8"),
+ encode("a", "utf-8"),
+ decode($"c", "utf-8"),
+ decode("c", "utf-8")),
Row(bytes, bytes, "大千世界", "大千世界"))
checkAnswer(
- df.selectExpr("encode(a, b)", "decode(c, b)"),
+ df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
Row(bytes, "大千世界"))
// scalastyle:on
}