diff options
11 files changed, 156 insertions, 30 deletions
diff --git a/.rat-excludes b/.rat-excludes index 236c2db053..7277146584 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -93,3 +93,4 @@ INDEX .lintr gen-java.* .*avpr +org.apache.spark.sql.sources.DataSourceRegister diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..cc32d4b727 --- /dev/null +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.jdbc.DefaultSource +org.apache.spark.sql.json.DefaultSource +org.apache.spark.sql.parquet.DefaultSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 0cdb407ad5..8c2f297e42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,7 +17,12 @@ package org.apache.spark.sql.execution.datasources +import java.util.ServiceLoader + +import scala.collection.Iterator +import scala.collection.JavaConversions._ import scala.language.{existentials, implicitConversions} +import scala.util.{Failure, Success, Try} import scala.util.matching.Regex import org.apache.hadoop.fs.Path @@ -190,37 +195,32 @@ private[sql] class DDLParser( } } -private[sql] object ResolvedDataSource { - - private val builtinSources = Map( - "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", - "json" -> "org.apache.spark.sql.json.DefaultSource", - "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", - "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" - ) +private[sql] object ResolvedDataSource extends Logging { /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val provider2 = s"$provider.DefaultSource" val loader = Utils.getContextOrSparkClassLoader - - if (builtinSources.contains(provider)) { - return loader.loadClass(builtinSources(provider)) - } - - try { - loader.loadClass(provider) - } catch { - case cnf: java.lang.ClassNotFoundException => - try { - loader.loadClass(provider + ".DefaultSource") - } catch { - case cnf: java.lang.ClassNotFoundException => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - sys.error("The ORC data source must be used with Hive support enabled.") - } else { - sys.error(s"Failed to load class for data source: $provider") - } + val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) + + serviceLoader.iterator().filter(_.format().equalsIgnoreCase(provider)).toList match { + /** the provider format did not match any given registered aliases */ + case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => dataSource + case Failure(error) => if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + throw new ClassNotFoundException( + s"Failed to load class for data source: $provider", error) } + } + /** there is exactly one registered alias */ + case head :: Nil => head.getClass + /** There are multiple registered aliases for the input */ + case sources => sys.error(s"Multiple sources found for $provider, " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 41d0ecb4bb..48d97ced9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -77,7 +77,10 @@ private[sql] object JDBCRelation { } } -private[sql] class DefaultSource extends RelationProvider { +private[sql] class DefaultSource extends RelationProvider with DataSourceRegister { + + def format(): String = "jdbc" + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 10f1367e69..b34a272ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -37,7 +37,10 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "json" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 48009b2fd0..b6db71b5b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -49,7 +49,10 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "parquet" + override def createRelation( sqlContext: SQLContext, paths: Array[String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index c5b7ee73eb..4aafec0e2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -39,6 +39,27 @@ import org.apache.spark.util.SerializableConfiguration /** * ::DeveloperApi:: + * Data sources should implement this trait so that they can register an alias to their data source. + * This allows users to give the data source alias as the format type over the fully qualified + * class name. + * + * ex: parquet.DefaultSource.format = "parquet". + * + * A new instance of this class with be instantiated each time a DDL call is made. + */ +@DeveloperApi +trait DataSourceRegister { + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source, + * ex: override def format(): String = "parquet" + */ + def format(): String +} + +/** + * ::DeveloperApi:: * Implemented by objects that produce relations for a specific kind of data source. When * Spark SQL is given a DDL operation with a USING clause specified (to specify the implemented * RelationProvider), this interface is used to pass in the parameters specified by a user. diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..cfd7889b4a --- /dev/null +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,3 @@ +org.apache.spark.sql.sources.FakeSourceOne +org.apache.spark.sql.sources.FakeSourceTwo +org.apache.spark.sql.sources.FakeSourceThree diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala new file mode 100644 index 0000000000..1a4d41b02c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -0,0 +1,85 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.sources + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class FakeSourceOne extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceTwo extends RelationProvider with DataSourceRegister { + + def format(): String = "Fluet da Bomb" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeSourceThree extends RelationProvider with DataSourceRegister { + + def format(): String = "gathering quorum" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} +// please note that the META-INF/services had to be modified for the test directory for this to work +class DDLSourceLoadSuite extends DataSourceTest { + + test("data sources with the same name") { + intercept[RuntimeException] { + caseInsensitiveContext.read.format("Fluet da Bomb").load() + } + } + + test("load data source from format alias") { + caseInsensitiveContext.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("specify full classname with duplicate formats") { + caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + } + + test("Loading Orc") { + intercept[ClassNotFoundException] { + caseInsensitiveContext.read.format("orc").load() + } + } +} diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000..4a774fbf1f --- /dev/null +++ b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.hive.orc.DefaultSource diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7c8704b47f..0c344c63fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -47,7 +47,10 @@ import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ -private[sql] class DefaultSource extends HadoopFsRelationProvider { +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + def format(): String = "orc" + def createRelation( sqlContext: SQLContext, paths: Array[String], |