author     CHOIJAEHONG <redrock07@naver.com>                          2015-09-03 13:38:26 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>           2015-09-03 13:38:26 -0700
commit     af0e3125cb1d48b1fc0e44c42b6880d67a9f1a85 (patch)
tree       b13448cb7fecabac7323c4d926a6cbd5a6085912
parent     3abc0d512541158d11b181e2d9fa126d1371d5c0 (diff)
[SPARK-8951] [SPARKR] support Unicode characters in collect()
Spark gives an error message and does not show the output when a field of the result DataFrame contains CJK characters. I changed SerDe.scala so that Spark supports Unicode characters when writing a string to R.

Author: CHOIJAEHONG <redrock07@naver.com>

Closes #7494 from CHOIJAEHONG1/SPARK-8951.
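For context, a minimal Scala sketch (not part of this patch) of why the old length prefix broke on CJK text: `value.length` counts UTF-16 code units and `DataOutputStream.writeBytes` keeps only the low byte of each character, so for non-ASCII strings the declared length and the bytes actually written to R disagree.

```scala
// Minimal illustration: character count vs. UTF-8 byte count.
// A length prefix based on value.length under-counts the bytes
// that a UTF-8 payload actually needs on the wire.
object LengthDemo {
  def main(args: Array[String]): Unit = {
    val s = "您好"                      // 2 characters
    val utf8 = s.getBytes("UTF-8")      // 6 bytes
    println(s"chars=${s.length}, utf8Bytes=${utf8.length}")
  }
}
```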
-rw-r--r--  R/pkg/R/deserialize.R                                   |  6
-rw-r--r--  R/pkg/R/serialize.R                                     |  2
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R                        | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala  |  9
4 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 6cf628e300..88f18613fd 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -57,8 +57,10 @@ readTypedObject <- function(con, type) {
readString <- function(con) {
stringLen <- readInt(con)
- string <- readBin(con, raw(), stringLen, endian = "big")
- rawToChar(string)
+ raw <- readBin(con, raw(), stringLen, endian = "big")
+ string <- rawToChar(raw)
+ Encoding(string) <- "UTF-8"
+ string
}
readInt <- function(con) {
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index e3676f57f9..91e6b3e560 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -79,7 +79,7 @@ writeJobj <- function(con, value) {
writeString <- function(con, value) {
utfVal <- enc2utf8(value)
writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1))
- writeBin(utfVal, con, endian = "big")
+ writeBin(utfVal, con, endian = "big", useBytes=TRUE)
}
writeInt <- function(con, value) {
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 0da5e38654..6d331f9883 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -431,6 +431,32 @@ test_that("collect() and take() on a DataFrame return the same number of rows an
expect_equal(ncol(collect(df)), ncol(take(df, 10)))
})
+test_that("collect() support Unicode characters", {
+ markUtf8 <- function(s) {
+ Encoding(s) <- "UTF-8"
+ s
+ }
+
+ lines <- c("{\"name\":\"안녕하세요\"}",
+ "{\"name\":\"您好\", \"age\":30}",
+ "{\"name\":\"こんにちは\", \"age\":19}",
+ "{\"name\":\"Xin chào\"}")
+
+ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(lines, jsonPath)
+
+ df <- read.df(sqlContext, jsonPath, "json")
+ rdf <- collect(df)
+ expect_true(is.data.frame(rdf))
+ expect_equal(rdf$name[1], markUtf8("안녕하세요"))
+ expect_equal(rdf$name[2], markUtf8("您好"))
+ expect_equal(rdf$name[3], markUtf8("こんにちは"))
+ expect_equal(rdf$name[4], markUtf8("Xin chào"))
+
+ df1 <- createDataFrame(sqlContext, rdf)
+ expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好"))
+})
+
test_that("multiple pipeline transformations result in an RDD with the correct values", {
df <- jsonFile(sqlContext, jsonPath)
first <- lapply(df, function(row) {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 26ad4f1d46..190e193427 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -329,12 +329,11 @@ private[spark] object SerDe {
out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9)
}
- // NOTE: Only works for ASCII right now
def writeString(out: DataOutputStream, value: String): Unit = {
- val len = value.length
- out.writeInt(len + 1) // For the \0
- out.writeBytes(value)
- out.writeByte(0)
+ val utf8 = value.getBytes("UTF-8")
+ val len = utf8.length
+ out.writeInt(len)
+ out.write(utf8, 0, len)
}
def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
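A rough sketch of the new framing, assuming only plain JDK streams (no SparkR runtime): a 4-byte big-endian length prefix holding the UTF-8 byte count, followed by the raw UTF-8 bytes with no trailing NUL, which is what the updated readString on the R side reads back with readBin/rawToChar.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Sketch of the patched writeString framing: UTF-8 byte-length prefix,
// then the UTF-8 bytes themselves.
object WriteStringSketch {
  def writeString(out: DataOutputStream, value: String): Unit = {
    val utf8 = value.getBytes("UTF-8")
    out.writeInt(utf8.length)
    out.write(utf8, 0, utf8.length)
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeString(new DataOutputStream(buf), "안녕하세요")
    // 4-byte prefix + 15 UTF-8 bytes for five Hangul syllables
    println(buf.toByteArray.length) // 19
  }
}
```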