From b8f88d32723eaea4807c10b5b79d0c76f30b0510 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 10 Feb 2015 20:40:21 -0800
Subject: [SPARK-5702][SQL] Allow short names for built-in data sources.

Also took the chance to fix up some style ...

Author: Reynold Xin

Closes #4489 from rxin/SPARK-5702 and squashes the following commits:

74f42e3 [Reynold Xin] [SPARK-5702][SQL] Allow short names for built-in data sources.
---
 .../org/apache/spark/sql/jdbc/JDBCRelation.scala   | 26 ++++----
 .../org/apache/spark/sql/json/JSONRelation.scala   |  1 +
 .../scala/org/apache/spark/sql/sources/ddl.scala   | 77 ++++++++++++----------
 .../sql/sources/ResolvedDataSourceSuite.scala      | 34 ++++++++++
 4 files changed, 90 insertions(+), 48 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 66ad38eb7c..beb76f2c55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -48,11 +48,6 @@ private[sql] object JDBCRelation {
    * exactly once. The parameters minValue and maxValue are advisory in that
    * incorrect values may cause the partitioning to be poor, but no data
    * will fail to be represented.
-   *
-   * @param column - Column name. Must refer to a column of integral type.
-   * @param numPartitions - Number of partitions
-   * @param minValue - Smallest value of column. Advisory.
-   * @param maxValue - Largest value of column. Advisory.
    */
   def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
     if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -68,12 +63,17 @@ private[sql] object JDBCRelation {
     var currentValue: Long = partitioning.lowerBound
     var ans = new ArrayBuffer[Partition]()
     while (i < numPartitions) {
-      val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+      val lowerBound = if (i != 0) s"$column >= $currentValue" else null
       currentValue += stride
-      val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
-      val whereClause = (if (upperBound == null) lowerBound
-        else if (lowerBound == null) upperBound
-        else s"$lowerBound AND $upperBound")
+      val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+      val whereClause =
+        if (upperBound == null) {
+          lowerBound
+        } else if (lowerBound == null) {
+          upperBound
+        } else {
+          s"$lowerBound AND $upperBound"
+        }
       ans += JDBCPartition(whereClause, i)
       i = i + 1
     }
@@ -96,8 +96,7 @@ private[sql] class DefaultSource extends RelationProvider {
 
     if (driver != null) Class.forName(driver)
 
-    if (
-        partitionColumn != null
+    if (partitionColumn != null
         && (lowerBound == null || upperBound == null || numPartitions == null)) {
       sys.error("Partitioning incompletely specified")
     }
@@ -119,7 +118,8 @@ private[sql] class DefaultSource extends RelationProvider {
 private[sql] case class JDBCRelation(
     url: String,
     table: String,
-    parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
+    parts: Array[Partition])(@transient val sqlContext: SQLContext)
+  extends PrunedFilteredScan {
 
   override val schema = JDBCRDD.resolveTable(url, table)
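The rewritten whereClause block above is a pure style change; the partition arithmetic is untouched. For reference, a minimal standalone sketch of that arithmetic (JDBCPartition is simplified to a plain case class here, and the column name and bounds are made up for illustration):

case class JDBCPartition(whereClause: String, idx: Int)

def columnPartition(
    column: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int): Array[JDBCPartition] = {
  // Each partition covers one stride-sized slice of [lowerBound, upperBound);
  // the first partition is open below and the last is open above, so no rows
  // outside the advisory bounds are dropped.
  val stride = (upperBound - lowerBound) / numPartitions
  (0 until numPartitions).map { i =>
    val lo = if (i != 0) s"$column >= ${lowerBound + i * stride}" else null
    val hi = if (i != numPartitions - 1) s"$column < ${lowerBound + (i + 1) * stride}" else null
    val whereClause =
      if (hi == null) lo
      else if (lo == null) hi
      else s"$lo AND $hi"
    JDBCPartition(whereClause, i)
  }.toArray
}

// columnPartition("id", 0, 100, 4) produces the WHERE fragments:
//   "id < 25", "id >= 25 AND id < 50", "id >= 50 AND id < 75", "id >= 75"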
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index f828bcdd65..51ff2443f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.json
 import java.io.IOException
 
 import org.apache.hadoop.fs.Path
+
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 6487c14b1e..d3d72089c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -234,65 +234,73 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
     primitiveType
   }
 
-object ResolvedDataSource {
-  def apply(
-      sqlContext: SQLContext,
-      userSpecifiedSchema: Option[StructType],
-      provider: String,
-      options: Map[String, String]): ResolvedDataSource = {
+private[sql] object ResolvedDataSource {
+
+  private val builtinSources = Map(
+    "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
+    "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
+    "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+  )
+
+  /** Given a provider name, look up the data source class definition. */
+  def lookupDataSource(provider: String): Class[_] = {
+    if (builtinSources.contains(provider)) {
+      return builtinSources(provider)
+    }
+
     val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
+    try {
+      loader.loadClass(provider)
+    } catch {
       case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
+        try {
+          loader.loadClass(provider + ".DefaultSource")
+        } catch {
           case cnf: java.lang.ClassNotFoundException =>
             sys.error(s"Failed to load class for data source: $provider")
         }
     }
+  }
 
+  /** Create a [[ResolvedDataSource]] for reading data in. */
+  def apply(
+      sqlContext: SQLContext,
+      userSpecifiedSchema: Option[StructType],
+      provider: String,
+      options: Map[String, String]): ResolvedDataSource = {
+    val clazz: Class[_] = lookupDataSource(provider)
     val relation = userSpecifiedSchema match {
-      case Some(schema: StructType) => {
-        clazz.newInstance match {
-          case dataSource: SchemaRelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
-          case dataSource: org.apache.spark.sql.sources.RelationProvider =>
-            sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
-        }
-      }
+      case Some(schema: StructType) => clazz.newInstance() match {
+        case dataSource: SchemaRelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+        case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+          sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
+      }
 
-      case None => {
-        clazz.newInstance match {
-          case dataSource: RelationProvider =>
-            dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
-          case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
-            sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
-        }
-      }
+      case None => clazz.newInstance() match {
+        case dataSource: RelationProvider =>
+          dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+        case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+          sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
+      }
     }
-
     new ResolvedDataSource(clazz, relation)
   }
 
+  /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
   def apply(
       sqlContext: SQLContext,
       provider: String,
       mode: SaveMode,
       options: Map[String, String],
       data: DataFrame): ResolvedDataSource = {
-    val loader = Utils.getContextOrSparkClassLoader
-    val clazz: Class[_] = try loader.loadClass(provider) catch {
-      case cnf: java.lang.ClassNotFoundException =>
-        try loader.loadClass(provider + ".DefaultSource") catch {
-          case cnf: java.lang.ClassNotFoundException =>
-            sys.error(s"Failed to load class for data source: $provider")
-        }
-    }
-
-    val relation = clazz.newInstance match {
+    val clazz: Class[_] = lookupDataSource(provider)
+    val relation = clazz.newInstance() match {
       case dataSource: CreatableRelationProvider =>
         dataSource.createRelation(sqlContext, mode, options, data)
       case _ =>
         sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
     }
-
     new ResolvedDataSource(clazz, relation)
   }
 }
@@ -405,6 +413,5 @@ protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[St
 
 /**
  * The exception thrown from the DDL parser.
- * @param message
 */
 protected[sql] class DDLException(message: String) extends Exception(message)
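The user-visible effect of lookupDataSource: the USING clause handled by this DDL parser now accepts the short names registered in builtinSources in place of a fully qualified package name. A small usage sketch (the table name and path are hypothetical):

// The short name "json" now resolves to org.apache.spark.sql.json.DefaultSource.
// Table name and path below are illustrative only.
sqlContext.sql(
  """CREATE TEMPORARY TABLE people
    |USING json
    |OPTIONS (path 'examples/src/main/resources/people.json')
  """.stripMargin)

// Before this patch, the provider had to be spelled out:
//   USING org.apache.spark.sql.json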
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
new file mode 100644
index 0000000000..8331a14c92
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import org.scalatest.FunSuite
+
+class ResolvedDataSourceSuite extends FunSuite {
+
+  test("builtin sources") {
+    assert(ResolvedDataSource.lookupDataSource("jdbc") ===
+      classOf[org.apache.spark.sql.jdbc.DefaultSource])
+
+    assert(ResolvedDataSource.lookupDataSource("json") ===
+      classOf[org.apache.spark.sql.json.DefaultSource])
+
+    assert(ResolvedDataSource.lookupDataSource("parquet") ===
+      classOf[org.apache.spark.sql.parquet.DefaultSource])
+  }
+}
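Note that builtinSources only short-circuits the three bundled providers; any other name still goes through the class-loader fallback in lookupDataSource, tried first as a literal class name and then with ".DefaultSource" appended. A hypothetical extra assertion in the style of the suite above (not part of this patch) would still pass:

// Hypothetical: a fully qualified package name resolves via the
// ".DefaultSource" fallback rather than the built-in map.
assert(ResolvedDataSource.lookupDataSource("org.apache.spark.sql.json") ===
  classOf[org.apache.spark.sql.json.DefaultSource])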
--
cgit v1.2.3