author     Reynold Xin <rxin@databricks.com>   2015-10-23 13:04:06 -0700
committer  Yin Huai <yhuai@databricks.com>     2015-10-23 13:04:06 -0700
commit     e1a897b657eb62e837026f7b3efafb9a6424ec4f (patch)
tree       6c9553d281c65342db5ff692cbe3f4814a539eed
parent     4e38defae13b2b13e196b4d172722ef5e6266c66 (diff)
[SPARK-11274] [SQL] Text data source support for Spark SQL.
This adds API for reading and writing text files, similar to SparkContext.textFile and RDD.saveAsTextFile.

```
SQLContext.read.text("/path/to/something.txt")
DataFrame.write.text("/path/to/write.txt")
```

Using the new Dataset API, this also supports

```
val ds: Dataset[String] = SQLContext.read.text("/path/to/something.txt").as[String]
```

Author: Reynold Xin <rxin@databricks.com>

Closes #9240 from rxin/SPARK-11274.
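For orientation, a minimal end-to-end sketch of the new API, assuming a build that contains this patch (1.6.0-style SQLContext); the app name and paths below are placeholders, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Placeholder setup; any existing SQLContext works the same way.
val sc = new SparkContext(new SparkConf().setAppName("text-demo").setMaster("local[*]"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// Read: each line of the file becomes one row in a single string column named "text".
val df = sqlContext.read.text("/path/to/input.txt")

// The same data viewed through the new Dataset API.
val ds = df.as[String]
println(ds.count())

// Write: the DataFrame must have exactly one string-typed column;
// each row is written out as one line.
df.write.text("/path/to/output")
```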
-rw-r--r--  sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala                            |  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                            |  18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala    |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala   | 160
-rw-r--r--  sql/core/src/test/resources/text-suite.txt                                                    |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala       |  81
7 files changed, 283 insertions, 4 deletions
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index ca50000b47..1ca2044057 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1,3 +1,4 @@
org.apache.spark.sql.execution.datasources.jdbc.DefaultSource
org.apache.spark.sql.execution.datasources.json.DefaultSource
org.apache.spark.sql.execution.datasources.parquet.DefaultSource
+org.apache.spark.sql.execution.datasources.text.DefaultSource
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index e8651a3569..824220d85e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -302,6 +302,22 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
DataFrame(sqlContext, sqlContext.catalog.lookupRelation(TableIdentifier(tableName)))
}
+ /**
+ * Loads a text file and returns a [[DataFrame]] with a single string column named "text".
+ * Each line in the text file is a new row in the resulting DataFrame. For example:
+ * {{{
+ * // Scala:
+ * sqlContext.read.text("/path/to/spark/README.md")
+ *
+ * // Java:
+ * sqlContext.read().text("/path/to/spark/README.md")
+ * }}}
+ *
+ * @param path input path
+ * @since 1.6.0
+ */
+ def text(path: String): DataFrame = format("text").load(path)
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 764510ab4b..7887e559a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -244,6 +244,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included.
+ *
+ * @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
val props = new Properties()
@@ -317,6 +319,22 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
def orc(path: String): Unit = format("orc").save(path)
+ /**
+ * Saves the content of the [[DataFrame]] in a text file at the specified path.
+ * The DataFrame must have only one column that is of string type.
+ * Each row becomes a new line in the output file. For example:
+ * {{{
+ * // Scala:
+ * df.write.text("/path/to/output")
+ *
+ * // Java:
+ * df.write().text("/path/to/output")
+ * }}}
+ *
+ * @since 1.6.0
+ */
+ def text(path: String): Unit = format("text").save(path)
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index d05e6efa83..794b889a93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -161,11 +161,10 @@ private[json] class JsonOutputWriter(
context: TaskAttemptContext)
extends OutputWriter with SparkHadoopMapRedUtil with Logging {
- val writer = new CharArrayWriter()
+ private[this] val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
- val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
-
- val result = new Text()
+ private[this] val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+ private[this] val result = new Text()
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
new file mode 100644
index 0000000000..ab26c57ad1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.text
+
+import com.google.common.base.Objects
+import org.apache.hadoop.fs.{Path, FileStatus}
+import org.apache.hadoop.io.{NullWritable, Text, LongWritable}
+import org.apache.hadoop.mapred.{TextInputFormat, JobConf}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
+import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.execution.datasources.PartitionSpec
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A data source for reading text files.
+ */
+class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister {
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ paths: Array[String],
+ dataSchema: Option[StructType],
+ partitionColumns: Option[StructType],
+ parameters: Map[String, String]): HadoopFsRelation = {
+ dataSchema.foreach(verifySchema)
+ new TextRelation(None, partitionColumns, paths)(sqlContext)
+ }
+
+ override def shortName(): String = "text"
+
+ private def verifySchema(schema: StructType): Unit = {
+ if (schema.size != 1) {
+ throw new AnalysisException(
+ s"Text data source supports only a single column, and you have ${schema.size} columns.")
+ }
+ val tpe = schema(0).dataType
+ if (tpe != StringType) {
+ throw new AnalysisException(
+ s"Text data source supports only a string column, but you have ${tpe.simpleString}.")
+ }
+ }
+}
+
+private[sql] class TextRelation(
+ val maybePartitionSpec: Option[PartitionSpec],
+ override val userDefinedPartitionColumns: Option[StructType],
+ override val paths: Array[String] = Array.empty[String])
+ (@transient val sqlContext: SQLContext)
+ extends HadoopFsRelation(maybePartitionSpec) {
+
+ /** Data schema is always a single column, named "text". */
+ override def dataSchema: StructType = new StructType().add("text", StringType)
+
+ /** This is an internal data source that outputs internal row format. */
+ override val needConversion: Boolean = false
+
+ /** Read path. */
+ override def buildScan(inputPaths: Array[FileStatus]): RDD[Row] = {
+ val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
+ val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job)
+ val paths = inputPaths.map(_.getPath).sortBy(_.toUri)
+
+ if (paths.nonEmpty) {
+ FileInputFormat.setInputPaths(job, paths: _*)
+ }
+
+ sqlContext.sparkContext.hadoopRDD(
+ conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
+ .mapPartitions { iter =>
+ var buffer = new Array[Byte](1024)
+ val row = new GenericMutableRow(1)
+ iter.map { case (_, line) =>
+ if (line.getLength > buffer.length) {
+ buffer = new Array[Byte](line.getLength)
+ }
+ System.arraycopy(line.getBytes, 0, buffer, 0, line.getLength)
+ row.update(0, UTF8String.fromBytes(buffer, 0, line.getLength))
+ row
+ }
+ }.asInstanceOf[RDD[Row]]
+ }
+
+ /** Write path. */
+ override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+ new OutputWriterFactory {
+ override def newInstance(
+ path: String,
+ dataSchema: StructType,
+ context: TaskAttemptContext): OutputWriter = {
+ new TextOutputWriter(path, dataSchema, context)
+ }
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case that: TextRelation =>
+ paths.toSet == that.paths.toSet && partitionColumns == that.partitionColumns
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Objects.hashCode(paths.toSet, partitionColumns)
+ }
+}
+
+class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
+ extends OutputWriter
+ with SparkHadoopMapRedUtil {
+
+ private[this] val buffer = new Text()
+
+ private val recordWriter: RecordWriter[NullWritable, Text] = {
+ new TextOutputFormat[NullWritable, Text]() {
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
+ val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID")
+ val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context)
+ val split = taskAttemptId.getTaskID.getId
+ new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+ }
+ }.getRecordWriter(context)
+ }
+
+ override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
+
+ override protected[sql] def writeInternal(row: InternalRow): Unit = {
+ val utf8string = row.getUTF8String(0)
+ buffer.set(utf8string.getBytes)
+ recordWriter.write(NullWritable.get(), buffer)
+ }
+
+ override def close(): Unit = {
+ recordWriter.close(context)
+ }
+}
diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/text-suite.txt
new file mode 100644
index 0000000000..e8fd967197
--- /dev/null
+++ b/sql/core/src/test/resources/text-suite.txt
@@ -0,0 +1,4 @@
+This is a test file for the text data source
+1+1
+数据砖头
+"doh"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
new file mode 100644
index 0000000000..0a2306c066
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.text
+
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.util.Utils
+
+
+class TextSuite extends QueryTest with SharedSQLContext {
+
+ test("reading text file") {
+ verifyFrame(sqlContext.read.format("text").load(testFile))
+ }
+
+ test("SQLContext.read.text() API") {
+ verifyFrame(sqlContext.read.text(testFile))
+ }
+
+ test("writing") {
+ val df = sqlContext.read.text(testFile)
+
+ val tempFile = Utils.createTempDir()
+ tempFile.delete()
+ df.write.text(tempFile.getCanonicalPath)
+ verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))
+
+ Utils.deleteRecursively(tempFile)
+ }
+
+ test("error handling for invalid schema") {
+ val tempFile = Utils.createTempDir()
+ tempFile.delete()
+
+ val df = sqlContext.range(2)
+ intercept[AnalysisException] {
+ df.write.text(tempFile.getCanonicalPath)
+ }
+
+ intercept[AnalysisException] {
+ sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
+ }
+ }
+
+ private def testFile: String = {
+ Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
+ }
+
+ /** Verifies data and schema. */
+ private def verifyFrame(df: DataFrame): Unit = {
+ // schema
+ assert(df.schema == new StructType().add("text", StringType))
+
+ // verify content
+ val data = df.collect()
+ assert(data(0) == Row("This is a test file for the text data source"))
+ assert(data(1) == Row("1+1"))
+ // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ // scalastyle:off
+ assert(data(2) == Row("数据砖头"))
+ // scalastyle:on
+ assert(data(3) == Row("\"doh\""))
+ assert(data.length == 4)
+ }
+}
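Because the new DefaultSource registers the short name "text" through DataSourceRegister (first hunk above), the source can also be selected explicitly by format name; the test suite exercises both spellings. A small sketch with placeholder paths:

```scala
// Equivalent to sqlContext.read.text(...) / df.write.text(...):
// resolve the data source by its registered short name "text".
val df = sqlContext.read.format("text").load("/path/to/input.txt")
df.write.format("text").save("/path/to/output")
```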